Repository: unslothai/unsloth Branch: main Commit: d0e5a1d61e5c Files: 759 Total size: 5.8 MB Directory structure: gitextract_mbg5_5ju/ ├── .gitattributes ├── .github/ │ ├── CODEOWNERS │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── bug---issue.md │ │ └── feature-request.md │ └── workflows/ │ └── stale.yml ├── .gitignore ├── .pre-commit-ci.yaml ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── COPYING ├── LICENSE ├── README.md ├── build.sh ├── cli.py ├── install.ps1 ├── install.sh ├── pyproject.toml ├── scripts/ │ ├── enforce_kwargs_spacing.py │ └── run_ruff_format.py ├── studio/ │ ├── LICENSE.AGPL-3.0 │ ├── Unsloth_Studio_Colab.ipynb │ ├── __init__.py │ ├── backend/ │ │ ├── __init__.py │ │ ├── assets/ │ │ │ ├── __init__.py │ │ │ └── configs/ │ │ │ ├── __init__.py │ │ │ ├── full_finetune.yaml │ │ │ ├── inference_defaults.json │ │ │ ├── lora_text.yaml │ │ │ ├── model_defaults/ │ │ │ │ ├── default.yaml │ │ │ │ ├── embedding/ │ │ │ │ │ ├── unsloth_Qwen3-Embedding-0.6B.yaml │ │ │ │ │ ├── unsloth_all-MiniLM-L6-v2.yaml │ │ │ │ │ ├── unsloth_bge-m3.yaml │ │ │ │ │ ├── unsloth_embeddinggemma-300m.yaml │ │ │ │ │ └── unsloth_gte-modernbert-base.yaml │ │ │ │ ├── ernie/ │ │ │ │ │ ├── unsloth_ERNIE-4.5-21B-A3B-PT.yaml │ │ │ │ │ └── unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml │ │ │ │ ├── falcon/ │ │ │ │ │ └── tiiuae_Falcon-H1-0.5B-Instruct.yaml │ │ │ │ ├── gemma/ │ │ │ │ │ ├── unsloth_codegemma-7b-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_functiongemma-270m-it.yaml │ │ │ │ │ ├── unsloth_gemma-2-27b-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_gemma-2-2b.yaml │ │ │ │ │ ├── unsloth_gemma-3-270m-it.yaml │ │ │ │ │ ├── unsloth_gemma-3-27b-it.yaml │ │ │ │ │ ├── unsloth_gemma-3-4b-it.yaml │ │ │ │ │ ├── unsloth_gemma-3-4b-pt.yaml │ │ │ │ │ ├── unsloth_gemma-3n-E4B-it.yaml │ │ │ │ │ └── unsloth_gemma-3n-E4B.yaml │ │ │ │ ├── gpt-oss/ │ │ │ │ │ ├── unsloth_gpt-oss-120b.yaml │ │ │ │ │ └── unsloth_gpt-oss-20b.yaml │ │ │ │ ├── granite/ │ │ │ │ │ ├── 
unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml │ │ │ │ │ └── unsloth_granite-4.0-h-micro.yaml │ │ │ │ ├── llama/ │ │ │ │ │ ├── unsloth_Llama-3.2-11B-Vision-Instruct.yaml │ │ │ │ │ ├── unsloth_Llama-3.2-1B-Instruct.yaml │ │ │ │ │ ├── unsloth_Llama-3.2-3B-Instruct.yaml │ │ │ │ │ ├── unsloth_Llama-3.3-70B-Instruct.yaml │ │ │ │ │ ├── unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_llama-3-8b-Instruct-bnb-4bit.yaml │ │ │ │ │ └── unsloth_llama-3-8b-bnb-4bit.yaml │ │ │ │ ├── llasa/ │ │ │ │ │ └── unsloth_Llasa-3B.yaml │ │ │ │ ├── mistral/ │ │ │ │ │ ├── unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_Ministral-3-3B-Instruct-2512.yaml │ │ │ │ │ ├── unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml │ │ │ │ │ ├── unsloth_Mistral-Small-Instruct-2409.yaml │ │ │ │ │ ├── unsloth_Pixtral-12B-2409.yaml │ │ │ │ │ ├── unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml │ │ │ │ │ └── unsloth_mistral-7b-v0.3-bnb-4bit.yaml │ │ │ │ ├── other/ │ │ │ │ │ ├── OuteAI_Llama-OuteTTS-1.0-1B.yaml │ │ │ │ │ ├── Spark-TTS-0.5B_LLM.yaml │ │ │ │ │ ├── sesame_csm-1b.yaml │ │ │ │ │ ├── unsloth_GLM-4.7-Flash.yaml │ │ │ │ │ ├── unsloth_LFM2-1.2B.yaml │ │ │ │ │ ├── unsloth_Nemotron-3-Nano-30B-A3B.yaml │ │ │ │ │ ├── unsloth_PaddleOCR-VL.yaml │ │ │ │ │ ├── unsloth_answerdotai_ModernBERT-large.yaml │ │ │ │ │ ├── unsloth_orpheus-3b-0.1-ft.yaml │ │ │ │ │ ├── unsloth_tinyllama-bnb-4bit.yaml │ │ │ │ │ └── unsloth_whisper-large-v3.yaml │ │ │ │ ├── phi/ │ │ │ │ │ ├── unsloth_Phi-3-medium-4k-instruct.yaml │ │ │ │ │ ├── unsloth_Phi-3.5-mini-instruct.yaml │ │ │ │ │ └── unsloth_Phi-4.yaml │ │ │ │ └── qwen/ │ │ │ │ ├── imdatta0_tiny_qwen3_moe_2.8B_0.7B.yaml │ │ │ │ ├── unsloth_Qwen2-7B.yaml │ │ │ │ ├── unsloth_Qwen2-VL-7B-Instruct.yaml │ │ │ │ ├── unsloth_Qwen2.5-1.5B-Instruct.yaml │ │ │ │ ├── unsloth_Qwen2.5-7B.yaml │ │ │ │ ├── unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml │ │ │ │ ├── unsloth_Qwen2.5-Coder-14B-Instruct.yaml │ │ │ │ 
├── unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml │ │ │ │ ├── unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml │ │ │ │ ├── unsloth_Qwen3-0.6B.yaml │ │ │ │ ├── unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml │ │ │ │ ├── unsloth_Qwen3-14B.yaml │ │ │ │ ├── unsloth_Qwen3-30B-A3B-Instruct-2507.yaml │ │ │ │ ├── unsloth_Qwen3-32B.yaml │ │ │ │ ├── unsloth_Qwen3-4B-Instruct-2507.yaml │ │ │ │ ├── unsloth_Qwen3-4B-Thinking-2507.yaml │ │ │ │ └── unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml │ │ │ └── vision_lora.yaml │ │ ├── auth/ │ │ │ ├── .gitkeep │ │ │ ├── __init__.py │ │ │ ├── authentication.py │ │ │ ├── hashing.py │ │ │ └── storage.py │ │ ├── colab.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── data_recipe/ │ │ │ │ ├── __init__.py │ │ │ │ ├── huggingface.py │ │ │ │ ├── jobs/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── constants.py │ │ │ │ │ ├── manager.py │ │ │ │ │ ├── parse.py │ │ │ │ │ ├── types.py │ │ │ │ │ └── worker.py │ │ │ │ ├── jsonable.py │ │ │ │ ├── local_callable_validators.py │ │ │ │ ├── oxc-validator/ │ │ │ │ │ ├── package.json │ │ │ │ │ └── validate.mjs │ │ │ │ └── service.py │ │ │ ├── export/ │ │ │ │ ├── __init__.py │ │ │ │ ├── export.py │ │ │ │ ├── orchestrator.py │ │ │ │ └── worker.py │ │ │ ├── inference/ │ │ │ │ ├── __init__.py │ │ │ │ ├── audio_codecs.py │ │ │ │ ├── defaults.py │ │ │ │ ├── inference.py │ │ │ │ ├── llama_cpp.py │ │ │ │ ├── orchestrator.py │ │ │ │ ├── tools.py │ │ │ │ └── worker.py │ │ │ └── training/ │ │ │ ├── __init__.py │ │ │ ├── trainer.py │ │ │ ├── training.py │ │ │ └── worker.py │ │ ├── loggers/ │ │ │ ├── .gitkeep │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── handlers.py │ │ ├── main.py │ │ ├── models/ │ │ │ ├── .gitkeep │ │ │ ├── __init__.py │ │ │ ├── auth.py │ │ │ ├── data_recipe.py │ │ │ ├── datasets.py │ │ │ ├── export.py │ │ │ ├── inference.py │ │ │ ├── models.py │ │ │ ├── responses.py │ │ │ ├── training.py │ │ │ └── users.py │ │ ├── plugins/ │ │ │ ├── __init__.py │ │ │ └── data-designer-unstructured-seed/ │ │ │ ├── 
__init__.py │ │ │ ├── pyproject.toml │ │ │ └── src/ │ │ │ └── data_designer_unstructured_seed/ │ │ │ ├── __init__.py │ │ │ ├── chunking.py │ │ │ ├── config.py │ │ │ ├── impl.py │ │ │ └── plugin.py │ │ ├── requirements/ │ │ │ ├── __init__.py │ │ │ ├── base.txt │ │ │ ├── extras-no-deps.txt │ │ │ ├── extras.txt │ │ │ ├── overrides.txt │ │ │ ├── single-env/ │ │ │ │ ├── constraints.txt │ │ │ │ ├── data-designer-deps.txt │ │ │ │ ├── data-designer.txt │ │ │ │ └── patch_metadata.py │ │ │ ├── studio.txt │ │ │ └── triton-kernels.txt │ │ ├── routes/ │ │ │ ├── .gitkeep │ │ │ ├── __init__.py │ │ │ ├── auth.py │ │ │ ├── data_recipe/ │ │ │ │ ├── __init__.py │ │ │ │ ├── jobs.py │ │ │ │ ├── mcp.py │ │ │ │ ├── seed.py │ │ │ │ └── validate.py │ │ │ ├── datasets.py │ │ │ ├── export.py │ │ │ ├── inference.py │ │ │ ├── models.py │ │ │ └── training.py │ │ ├── run.py │ │ ├── state/ │ │ │ ├── .gitkeep │ │ │ └── __init__.py │ │ ├── tests/ │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ ├── test_data_recipe_seed.py │ │ │ └── test_utils.py │ │ └── utils/ │ │ ├── .gitkeep │ │ ├── __init__.py │ │ ├── cache_cleanup.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── chat_templates.py │ │ │ ├── data_collators.py │ │ │ ├── dataset_utils.py │ │ │ ├── format_conversion.py │ │ │ ├── format_detection.py │ │ │ ├── llm_assist.py │ │ │ ├── model_mappings.py │ │ │ └── vlm_processing.py │ │ ├── hardware/ │ │ │ ├── __init__.py │ │ │ └── hardware.py │ │ ├── inference/ │ │ │ ├── __init__.py │ │ │ └── inference_config.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── checkpoints.py │ │ │ └── model_config.py │ │ ├── paths/ │ │ │ ├── __init__.py │ │ │ ├── path_utils.py │ │ │ └── storage_roots.py │ │ ├── transformers_version.py │ │ └── utils.py │ ├── frontend/ │ │ ├── .gitignore │ │ ├── .gitkeep │ │ ├── biome.json │ │ ├── components.json │ │ ├── data-designer.openapi (1).yaml │ │ ├── eslint.config.js │ │ ├── index.html │ │ ├── package.json │ │ ├── public/ │ │ │ └── Hellix font official/ │ │ │ └── OTF/ │ │ │ 
└── Hellix-SemiBold.otf │ │ ├── src/ │ │ │ ├── app/ │ │ │ │ ├── app.tsx │ │ │ │ ├── auth-guards.ts │ │ │ │ ├── provider.tsx │ │ │ │ ├── router.tsx │ │ │ │ └── routes/ │ │ │ │ ├── __root.tsx │ │ │ │ ├── change-password.tsx │ │ │ │ ├── chat.tsx │ │ │ │ ├── data-recipes.$recipeId.tsx │ │ │ │ ├── data-recipes.tsx │ │ │ │ ├── export.tsx │ │ │ │ ├── grid-test.tsx │ │ │ │ ├── index.tsx │ │ │ │ ├── login.tsx │ │ │ │ ├── onboarding.tsx │ │ │ │ └── studio.tsx │ │ │ ├── components/ │ │ │ │ ├── assistant-ui/ │ │ │ │ │ ├── attachment.tsx │ │ │ │ │ ├── audio-player.tsx │ │ │ │ │ ├── badge.tsx │ │ │ │ │ ├── markdown-text.tsx │ │ │ │ │ ├── message-timing.tsx │ │ │ │ │ ├── model-selector/ │ │ │ │ │ │ ├── pickers.tsx │ │ │ │ │ │ └── types.ts │ │ │ │ │ ├── model-selector.tsx │ │ │ │ │ ├── reasoning.tsx │ │ │ │ │ ├── sources.tsx │ │ │ │ │ ├── thread.tsx │ │ │ │ │ ├── tool-fallback.tsx │ │ │ │ │ ├── tool-group.tsx │ │ │ │ │ ├── tool-ui-python.tsx │ │ │ │ │ ├── tool-ui-terminal.tsx │ │ │ │ │ ├── tool-ui-web-search.tsx │ │ │ │ │ └── tooltip-icon-button.tsx │ │ │ │ ├── example.tsx │ │ │ │ ├── layout/ │ │ │ │ │ ├── dashboard-grid.tsx │ │ │ │ │ ├── dashboard-layout.tsx │ │ │ │ │ └── index.ts │ │ │ │ ├── markdown/ │ │ │ │ │ ├── markdown-preview.tsx │ │ │ │ │ └── mermaid-error.tsx │ │ │ │ ├── navbar.tsx │ │ │ │ ├── section-card.tsx │ │ │ │ └── ui/ │ │ │ │ ├── accordion.tsx │ │ │ │ ├── alert-dialog.tsx │ │ │ │ ├── alert.tsx │ │ │ │ ├── animated-shiny-text.tsx │ │ │ │ ├── animated-theme-toggler.tsx │ │ │ │ ├── aspect-ratio.tsx │ │ │ │ ├── avatar.tsx │ │ │ │ ├── badge.tsx │ │ │ │ ├── breadcrumb.tsx │ │ │ │ ├── button.tsx │ │ │ │ ├── calendar.tsx │ │ │ │ ├── card.tsx │ │ │ │ ├── chart.tsx │ │ │ │ ├── checkbox.tsx │ │ │ │ ├── collapsible.tsx │ │ │ │ ├── combobox.tsx │ │ │ │ ├── command.tsx │ │ │ │ ├── confetti.tsx │ │ │ │ ├── context-menu.tsx │ │ │ │ ├── data-table.tsx │ │ │ │ ├── dialog.tsx │ │ │ │ ├── dropdown-menu.tsx │ │ │ │ ├── empty.tsx │ │ │ │ ├── field.tsx │ │ │ │ ├── hover-card.tsx │ │ │ 
│ ├── input-group.tsx │ │ │ │ ├── input.tsx │ │ │ │ ├── label.tsx │ │ │ │ ├── light-rays.tsx │ │ │ │ ├── menubar.tsx │ │ │ │ ├── navigation-menu.tsx │ │ │ │ ├── pagination.tsx │ │ │ │ ├── popover.tsx │ │ │ │ ├── progress.tsx │ │ │ │ ├── radio-group.tsx │ │ │ │ ├── resizable.tsx │ │ │ │ ├── scroll-area.tsx │ │ │ │ ├── select.tsx │ │ │ │ ├── separator.tsx │ │ │ │ ├── sheet.tsx │ │ │ │ ├── shine-border.tsx │ │ │ │ ├── sidebar.tsx │ │ │ │ ├── skeleton.tsx │ │ │ │ ├── slider.tsx │ │ │ │ ├── sonner.tsx │ │ │ │ ├── sparkles-text.tsx │ │ │ │ ├── spinner.tsx │ │ │ │ ├── switch.tsx │ │ │ │ ├── table.tsx │ │ │ │ ├── tabs.tsx │ │ │ │ ├── terminal.tsx │ │ │ │ ├── textarea.tsx │ │ │ │ ├── toggle-group.tsx │ │ │ │ ├── toggle.tsx │ │ │ │ └── tooltip.tsx │ │ │ ├── config/ │ │ │ │ ├── env.ts │ │ │ │ └── training.ts │ │ │ ├── features/ │ │ │ │ ├── auth/ │ │ │ │ │ ├── api.ts │ │ │ │ │ ├── change-password-page.tsx │ │ │ │ │ ├── components/ │ │ │ │ │ │ └── auth-form.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── login-page.tsx │ │ │ │ │ └── session.ts │ │ │ │ ├── chat/ │ │ │ │ │ ├── api/ │ │ │ │ │ │ ├── chat-adapter.ts │ │ │ │ │ │ └── chat-api.ts │ │ │ │ │ ├── chat-page.tsx │ │ │ │ │ ├── chat-settings-sheet.tsx │ │ │ │ │ ├── components/ │ │ │ │ │ │ └── model-load-status.tsx │ │ │ │ │ ├── db.ts │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ └── use-chat-model-runtime.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── runtime-provider.tsx │ │ │ │ │ ├── shared-composer.tsx │ │ │ │ │ ├── stores/ │ │ │ │ │ │ └── chat-runtime-store.ts │ │ │ │ │ ├── thread-sidebar.tsx │ │ │ │ │ ├── tour/ │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ └── steps.tsx │ │ │ │ │ ├── types/ │ │ │ │ │ │ ├── api.ts │ │ │ │ │ │ └── runtime.ts │ │ │ │ │ ├── types.ts │ │ │ │ │ └── utils/ │ │ │ │ │ └── parse-assistant-content.ts │ │ │ │ ├── data-recipes/ │ │ │ │ │ ├── data/ │ │ │ │ │ │ └── recipes-db.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── learning-recipes/ │ │ │ │ │ │ ├── conversation.json │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ ├── instruction-from-answer.json │ 
│ │ │ │ │ ├── ocr-document-extraction.json │ │ │ │ │ │ ├── pdf-grounded-qa.json │ │ │ │ │ │ ├── structured-outputs-jinja.json │ │ │ │ │ │ ├── text-to-python.json │ │ │ │ │ │ └── text-to-sql.json │ │ │ │ │ ├── pages/ │ │ │ │ │ │ ├── data-recipes-page.tsx │ │ │ │ │ │ └── edit-recipe-page.tsx │ │ │ │ │ └── types.ts │ │ │ │ ├── export/ │ │ │ │ │ ├── anim.ts │ │ │ │ │ ├── api/ │ │ │ │ │ │ └── export-api.ts │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── export-dialog.tsx │ │ │ │ │ │ ├── method-picker.tsx │ │ │ │ │ │ └── quant-picker.tsx │ │ │ │ │ ├── constants.ts │ │ │ │ │ ├── export-page.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ └── tour/ │ │ │ │ │ ├── index.ts │ │ │ │ │ └── steps.tsx │ │ │ │ ├── onboarding/ │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── splash-screen.tsx │ │ │ │ │ │ ├── steps/ │ │ │ │ │ │ │ ├── dataset-step.tsx │ │ │ │ │ │ │ ├── hyperparameters-step.tsx │ │ │ │ │ │ │ ├── model-selection-step.tsx │ │ │ │ │ │ │ ├── model-type-step.tsx │ │ │ │ │ │ │ └── summary-step.tsx │ │ │ │ │ │ ├── wizard-content.tsx │ │ │ │ │ │ ├── wizard-footer.tsx │ │ │ │ │ │ ├── wizard-layout.tsx │ │ │ │ │ │ ├── wizard-sidebar.tsx │ │ │ │ │ │ └── wizard-step-item.tsx │ │ │ │ │ └── index.ts │ │ │ │ ├── recipe-studio/ │ │ │ │ │ ├── api/ │ │ │ │ │ │ └── index.ts │ │ │ │ │ ├── blocks/ │ │ │ │ │ │ ├── definitions.ts │ │ │ │ │ │ ├── registry.ts │ │ │ │ │ │ └── render-dialog.tsx │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── block-sheet.tsx │ │ │ │ │ │ ├── chip-input.tsx │ │ │ │ │ │ ├── controls/ │ │ │ │ │ │ │ ├── layout-controls.tsx │ │ │ │ │ │ │ ├── run-validate-floating-controls.tsx │ │ │ │ │ │ │ └── viewport-controls.tsx │ │ │ │ │ │ ├── executions/ │ │ │ │ │ │ │ ├── execution-columns-tab.tsx │ │ │ │ │ │ │ ├── execution-data-tab.tsx │ │ │ │ │ │ │ ├── execution-overview-tab.tsx │ │ │ │ │ │ │ ├── execution-raw-tab.tsx │ │ │ │ │ │ │ ├── execution-sidebar.tsx │ │ │ │ │ │ │ ├── executions-view-helpers.ts │ │ │ │ │ │ │ ├── executions-view.tsx │ │ │ │ │ │ │ └── publish-execution-dialog.tsx │ │ │ │ │ │ ├── 
graph/ │ │ │ │ │ │ │ └── internals-sync.tsx │ │ │ │ │ │ ├── inline/ │ │ │ │ │ │ │ ├── inline-category-badges.tsx │ │ │ │ │ │ │ ├── inline-expression.tsx │ │ │ │ │ │ │ ├── inline-field.tsx │ │ │ │ │ │ │ ├── inline-llm.tsx │ │ │ │ │ │ │ ├── inline-model.tsx │ │ │ │ │ │ │ ├── inline-policy.ts │ │ │ │ │ │ │ ├── inline-sampler.tsx │ │ │ │ │ │ │ └── inline-seed.tsx │ │ │ │ │ │ ├── recipe-floating-icon-button-class.ts │ │ │ │ │ │ ├── recipe-graph-aux-node.tsx │ │ │ │ │ │ ├── recipe-graph-node.tsx │ │ │ │ │ │ ├── recipe-graph-semantic-edge.tsx │ │ │ │ │ │ ├── recipe-studio-header.tsx │ │ │ │ │ │ ├── rf-ui/ │ │ │ │ │ │ │ ├── base-handle.tsx │ │ │ │ │ │ │ ├── base-node.tsx │ │ │ │ │ │ │ ├── data-edge.tsx │ │ │ │ │ │ │ └── labeled-handle.tsx │ │ │ │ │ │ ├── runtime/ │ │ │ │ │ │ │ └── execution-progress-island.tsx │ │ │ │ │ │ └── shared/ │ │ │ │ │ │ ├── available-references-inline.tsx │ │ │ │ │ │ └── hf-dataset-combobox.tsx │ │ │ │ │ ├── constants.ts │ │ │ │ │ ├── data/ │ │ │ │ │ │ └── executions-db.ts │ │ │ │ │ ├── dialogs/ │ │ │ │ │ │ ├── config-dialog.tsx │ │ │ │ │ │ ├── expression/ │ │ │ │ │ │ │ └── expression-dialog.tsx │ │ │ │ │ │ ├── import-dialog.tsx │ │ │ │ │ │ ├── llm/ │ │ │ │ │ │ │ ├── general-tab.tsx │ │ │ │ │ │ │ ├── llm-dialog.tsx │ │ │ │ │ │ │ └── scores-tab.tsx │ │ │ │ │ │ ├── markdown-note/ │ │ │ │ │ │ │ └── markdown-note-dialog.tsx │ │ │ │ │ │ ├── models/ │ │ │ │ │ │ │ ├── model-config-dialog.tsx │ │ │ │ │ │ │ └── model-provider-dialog.tsx │ │ │ │ │ │ ├── preview-dialog.tsx │ │ │ │ │ │ ├── processors-dialog.tsx │ │ │ │ │ │ ├── samplers/ │ │ │ │ │ │ │ ├── bernoulli-dialog.tsx │ │ │ │ │ │ │ ├── category-dialog.tsx │ │ │ │ │ │ │ ├── datetime-dialog.tsx │ │ │ │ │ │ │ ├── gaussian-dialog.tsx │ │ │ │ │ │ │ ├── person-dialog.tsx │ │ │ │ │ │ │ ├── subcategory-dialog.tsx │ │ │ │ │ │ │ ├── timedelta-dialog.tsx │ │ │ │ │ │ │ ├── uniform-dialog.tsx │ │ │ │ │ │ │ └── uuid-dialog.tsx │ │ │ │ │ │ ├── seed/ │ │ │ │ │ │ │ └── seed-dialog.tsx │ │ │ │ │ │ ├── shared/ │ │ │ │ │ 
│ │ ├── available-variables.tsx │ │ │ │ │ │ │ ├── collapsible-section-trigger.tsx │ │ │ │ │ │ │ ├── dialog-shell.tsx │ │ │ │ │ │ │ ├── field-label.tsx │ │ │ │ │ │ │ ├── name-field.tsx │ │ │ │ │ │ │ └── validation-banner.tsx │ │ │ │ │ │ ├── tool-profile/ │ │ │ │ │ │ │ ├── helpers.ts │ │ │ │ │ │ │ └── tool-profile-dialog.tsx │ │ │ │ │ │ └── validators/ │ │ │ │ │ │ └── validator-dialog.tsx │ │ │ │ │ ├── execution-types.ts │ │ │ │ │ ├── executions/ │ │ │ │ │ │ ├── execution-helpers.ts │ │ │ │ │ │ ├── hydration.ts │ │ │ │ │ │ ├── run-settings.ts │ │ │ │ │ │ ├── runtime.ts │ │ │ │ │ │ └── tracker.ts │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ ├── use-node-connection-status.ts │ │ │ │ │ │ ├── use-recipe-editor-graph.ts │ │ │ │ │ │ ├── use-recipe-executions.ts │ │ │ │ │ │ ├── use-recipe-persistence.ts │ │ │ │ │ │ ├── use-recipe-runtime-visuals.ts │ │ │ │ │ │ └── use-recipe-studio-actions.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── recipe-studio-page.tsx │ │ │ │ │ ├── stores/ │ │ │ │ │ │ ├── helpers/ │ │ │ │ │ │ │ ├── edge-sync.ts │ │ │ │ │ │ │ ├── model-infra-layout.ts │ │ │ │ │ │ │ ├── node-updates.ts │ │ │ │ │ │ │ ├── reference-sync.ts │ │ │ │ │ │ │ └── removals.ts │ │ │ │ │ │ ├── recipe-executions.ts │ │ │ │ │ │ ├── recipe-studio-helpers.ts │ │ │ │ │ │ └── recipe-studio.ts │ │ │ │ │ ├── types/ │ │ │ │ │ │ └── index.ts │ │ │ │ │ └── utils/ │ │ │ │ │ ├── config-factories.ts │ │ │ │ │ ├── config-labels.ts │ │ │ │ │ ├── config-type-guards.ts │ │ │ │ │ ├── graph/ │ │ │ │ │ │ ├── derive-display-graph.ts │ │ │ │ │ │ ├── fit-view.ts │ │ │ │ │ │ ├── recipe-graph-connection.ts │ │ │ │ │ │ ├── relations.ts │ │ │ │ │ │ └── runtime-visual-state.ts │ │ │ │ │ ├── graph-warnings.ts │ │ │ │ │ ├── graph.ts │ │ │ │ │ ├── handle-layout.ts │ │ │ │ │ ├── handles.ts │ │ │ │ │ ├── image-preview.ts │ │ │ │ │ ├── import/ │ │ │ │ │ │ ├── edges.ts │ │ │ │ │ │ ├── helpers.ts │ │ │ │ │ │ ├── importer.ts │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ ├── parsers/ │ │ │ │ │ │ │ ├── expression-parser.ts │ │ │ │ │ │ │ ├── 
llm-parser.ts │ │ │ │ │ │ │ ├── model-parser.ts │ │ │ │ │ │ │ ├── sampler-parser.ts │ │ │ │ │ │ │ ├── seed-config-parser.ts │ │ │ │ │ │ │ └── validator-parser.ts │ │ │ │ │ │ ├── parsers.ts │ │ │ │ │ │ ├── types.ts │ │ │ │ │ │ └── ui.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── layout.ts │ │ │ │ │ ├── naming.ts │ │ │ │ │ ├── node-data.ts │ │ │ │ │ ├── parse.ts │ │ │ │ │ ├── payload/ │ │ │ │ │ │ ├── build-payload.ts │ │ │ │ │ │ ├── builders-llm.ts │ │ │ │ │ │ ├── builders-model.ts │ │ │ │ │ │ ├── builders-processors.ts │ │ │ │ │ │ ├── builders-sampler.ts │ │ │ │ │ │ ├── builders-seed.ts │ │ │ │ │ │ ├── builders-validator.ts │ │ │ │ │ │ ├── builders.ts │ │ │ │ │ │ ├── empty.ts │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ ├── parse.ts │ │ │ │ │ │ ├── types.ts │ │ │ │ │ │ └── validate.ts │ │ │ │ │ ├── processors.ts │ │ │ │ │ ├── reactflow-changes.ts │ │ │ │ │ ├── recipe-studio-view.ts │ │ │ │ │ ├── refs.ts │ │ │ │ │ ├── rf-node-dimensions.ts │ │ │ │ │ ├── ui-tones.ts │ │ │ │ │ ├── validation.ts │ │ │ │ │ ├── validators/ │ │ │ │ │ │ ├── code-lang.ts │ │ │ │ │ │ ├── oxc-code-shape.ts │ │ │ │ │ │ └── oxc-mode.ts │ │ │ │ │ └── variables.ts │ │ │ │ ├── studio/ │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── sections/ │ │ │ │ │ │ ├── charts/ │ │ │ │ │ │ │ ├── chart-preferences-store.ts │ │ │ │ │ │ │ ├── chart-settings-sheet.tsx │ │ │ │ │ │ │ ├── eval-loss-chart-card.tsx │ │ │ │ │ │ │ ├── grad-norm-chart-card.tsx │ │ │ │ │ │ │ ├── learning-rate-chart-card.tsx │ │ │ │ │ │ │ ├── training-loss-chart-card.tsx │ │ │ │ │ │ │ ├── types.ts │ │ │ │ │ │ │ └── utils.ts │ │ │ │ │ │ ├── charts-content.tsx │ │ │ │ │ │ ├── charts-section.tsx │ │ │ │ │ │ ├── dataset-preview-dialog-mapping.tsx │ │ │ │ │ │ ├── dataset-preview-dialog-utils.ts │ │ │ │ │ │ ├── dataset-preview-dialog.tsx │ │ │ │ │ │ ├── dataset-section.tsx │ │ │ │ │ │ ├── document-upload-redirect-dialog.tsx │ │ │ │ │ │ ├── model-section.tsx │ │ │ │ │ │ ├── params-section.tsx │ │ │ │ │ │ ├── progress-section-lib.ts │ │ │ │ │ │ ├── progress-section.tsx │ │ 
│ │ │ │ └── training-section.tsx │ │ │ │ │ ├── studio-page.tsx │ │ │ │ │ ├── tour/ │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ ├── steps/ │ │ │ │ │ │ │ ├── base-model.tsx │ │ │ │ │ │ │ ├── dataset.tsx │ │ │ │ │ │ │ ├── index.tsx │ │ │ │ │ │ │ ├── local-model.tsx │ │ │ │ │ │ │ ├── method.tsx │ │ │ │ │ │ │ ├── nav.tsx │ │ │ │ │ │ │ ├── params.tsx │ │ │ │ │ │ │ ├── save.tsx │ │ │ │ │ │ │ └── start.tsx │ │ │ │ │ │ └── training/ │ │ │ │ │ │ ├── index.ts │ │ │ │ │ │ └── steps.tsx │ │ │ │ │ ├── training-start-overlay.tsx │ │ │ │ │ └── training-view.tsx │ │ │ │ ├── tour/ │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── guided-tour.tsx │ │ │ │ │ │ ├── read-more.tsx │ │ │ │ │ │ └── spotlight-overlay.tsx │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ └── use-guided-tour-controller.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ └── types.ts │ │ │ │ └── training/ │ │ │ │ ├── api/ │ │ │ │ │ ├── datasets-api.ts │ │ │ │ │ ├── mappers.ts │ │ │ │ │ ├── models-api.ts │ │ │ │ │ └── train-api.ts │ │ │ │ ├── components/ │ │ │ │ │ └── hf-dataset-subset-split-selectors.tsx │ │ │ │ ├── hooks/ │ │ │ │ │ ├── use-max-steps-epochs-toggle.ts │ │ │ │ │ ├── use-training-actions.ts │ │ │ │ │ └── use-training-runtime-lifecycle.ts │ │ │ │ ├── index.ts │ │ │ │ ├── stores/ │ │ │ │ │ ├── dataset-preview-dialog-store.ts │ │ │ │ │ ├── training-config-store.ts │ │ │ │ │ └── training-runtime-store.ts │ │ │ │ └── types/ │ │ │ │ ├── api.ts │ │ │ │ ├── config.ts │ │ │ │ ├── datasets.ts │ │ │ │ └── runtime.ts │ │ │ ├── hooks/ │ │ │ │ ├── index.ts │ │ │ │ ├── use-debounced-value.ts │ │ │ │ ├── use-gpu-info.ts │ │ │ │ ├── use-gpu-utilization.ts │ │ │ │ ├── use-hardware-info.ts │ │ │ │ ├── use-hf-dataset-search.ts │ │ │ │ ├── use-hf-dataset-splits.ts │ │ │ │ ├── use-hf-model-search.ts │ │ │ │ ├── use-hf-paginated-search.ts │ │ │ │ ├── use-hf-token-validation.ts │ │ │ │ ├── use-infinite-scroll.ts │ │ │ │ ├── use-mobile.ts │ │ │ │ └── use-recommended-model-vram.ts │ │ │ ├── index.css │ │ │ ├── main.tsx │ │ │ ├── shared/ │ │ │ │ └── toast.ts │ │ │ ├── 
speech-recognition.d.ts │ │ │ ├── stores/ │ │ │ │ ├── index.ts │ │ │ │ └── training.ts │ │ │ ├── types/ │ │ │ │ ├── index.ts │ │ │ │ └── training.ts │ │ │ └── utils/ │ │ │ ├── index.ts │ │ │ └── strings.ts │ │ ├── tsconfig.app.json │ │ ├── tsconfig.json │ │ ├── tsconfig.node.json │ │ └── vite.config.ts │ ├── install_python_stack.py │ ├── setup.bat │ ├── setup.ps1 │ └── setup.sh ├── tests/ │ ├── __init__.py │ ├── qlora/ │ │ ├── README.md │ │ ├── test_hf_qlora_train_and_merge.py │ │ └── test_unsloth_qlora_train_and_merge.py │ ├── saving/ │ │ ├── gpt-oss-merge/ │ │ │ ├── run_test.sh │ │ │ ├── test_merged_model.py │ │ │ └── train_and_merge.py │ │ ├── language_models/ │ │ │ ├── test_merge_4bit_validation.py │ │ │ ├── test_merge_model_perplexity_llama-3.2.py │ │ │ ├── test_merge_model_perplexity_mistral.py │ │ │ ├── test_merge_model_perplexity_phi_4.py │ │ │ ├── test_merged_model_perplexity_llama-3.1-8b.py │ │ │ ├── test_merged_model_perplexity_qwen_2.5.py │ │ │ ├── test_push_to_hub_merged.py │ │ │ ├── test_push_to_hub_merged_sharded_index_file.py │ │ │ └── test_save_merged_grpo_model.py │ │ ├── non_peft/ │ │ │ ├── test_mistral_non_peft.py │ │ │ └── test_whisper_non_peft.py │ │ ├── test_unsloth_save.py │ │ ├── text_to_speech_models/ │ │ │ ├── test_csm.py │ │ │ ├── test_lasa.py │ │ │ ├── test_orpheus.py │ │ │ └── test_whisper.py │ │ └── vision_models/ │ │ ├── test_index_file_sharded_model.py │ │ ├── test_push_to_hub_merged.py │ │ ├── test_save_merge_qwen2.5vl32B_model_ocr_benchmark.py │ │ └── test_save_merge_vision_model_ocr_benchmark.py │ ├── test_get_model_name.py │ ├── test_model_registry.py │ ├── test_raw_text.py │ └── utils/ │ ├── __init__.py │ ├── aime_eval.md │ ├── aime_eval.py │ ├── cleanup_utils.py │ ├── data_utils.py │ ├── hf_utils.py │ ├── ocr_eval.md │ ├── ocr_eval.py │ ├── os_utils.py │ ├── perplexity_eval.md │ ├── perplexity_eval.py │ ├── test_attention_masks.py │ ├── test_packing.py │ ├── test_qat.py │ └── test_trunc_normal_patch.py ├── unsloth/ │ ├── 
__init__.py │ ├── _auto_install.py │ ├── chat_templates.py │ ├── dataprep/ │ │ ├── __init__.py │ │ ├── raw_text.py │ │ ├── synthetic.py │ │ └── synthetic_configs.py │ ├── device_type.py │ ├── import_fixes.py │ ├── kernels/ │ │ ├── __init__.py │ │ ├── cross_entropy_loss.py │ │ ├── fast_lora.py │ │ ├── flex_attention.py │ │ ├── fp8.py │ │ ├── geglu.py │ │ ├── layernorm.py │ │ ├── moe/ │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── autotune_cache.py │ │ │ ├── benchmark/ │ │ │ │ ├── benchmark_fused_moe.py │ │ │ │ └── utils.py │ │ │ ├── grouped_gemm/ │ │ │ │ ├── LICENSE │ │ │ │ ├── __init__.py │ │ │ │ ├── interface.py │ │ │ │ ├── kernels/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── autotuning.py │ │ │ │ │ ├── backward.py │ │ │ │ │ ├── forward.py │ │ │ │ │ └── tuning.py │ │ │ │ └── reference/ │ │ │ │ ├── __init__.py │ │ │ │ ├── layers/ │ │ │ │ │ ├── llama4_moe.py │ │ │ │ │ └── qwen3_moe.py │ │ │ │ ├── moe_block.py │ │ │ │ └── moe_ops.py │ │ │ ├── requirements.txt │ │ │ └── tests/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── moe_utils.py │ │ │ ├── run_qwen3_moe_tests.sh │ │ │ ├── test_grouped_gemm.py │ │ │ ├── test_llama4_moe.py │ │ │ └── test_qwen3_moe.py │ │ ├── rms_layernorm.py │ │ ├── rope_embedding.py │ │ ├── swiglu.py │ │ └── utils.py │ ├── models/ │ │ ├── __init__.py │ │ ├── _utils.py │ │ ├── cohere.py │ │ ├── dpo.py │ │ ├── falcon_h1.py │ │ ├── gemma.py │ │ ├── gemma2.py │ │ ├── glm4_moe.py │ │ ├── granite.py │ │ ├── llama.py │ │ ├── llama4.py │ │ ├── loader.py │ │ ├── loader_utils.py │ │ ├── mapper.py │ │ ├── mistral.py │ │ ├── qwen2.py │ │ ├── qwen3.py │ │ ├── qwen3_moe.py │ │ ├── rl.py │ │ ├── rl_replacements.py │ │ ├── sentence_transformer.py │ │ └── vision.py │ ├── ollama_template_mappers.py │ ├── registry/ │ │ ├── REGISTRY.md │ │ ├── __init__.py │ │ ├── _deepseek.py │ │ ├── _gemma.py │ │ ├── _llama.py │ │ ├── _mistral.py │ │ ├── _phi.py │ │ ├── _qwen.py │ │ └── registry.py │ ├── save.py │ ├── tokenizer_utils.py │ ├── trainer.py │ 
└── utils/ │ ├── __init__.py │ ├── attention_dispatch.py │ ├── hf_hub.py │ └── packing.py ├── unsloth-cli.py └── unsloth_cli/ ├── __init__.py ├── commands/ │ ├── __init__.py │ ├── export.py │ ├── inference.py │ ├── studio.py │ ├── train.py │ └── ui.py ├── config.py └── options.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitattributes ================================================ # Normalize Python files to LF line endings *.py text eol=lf ================================================ FILE: .github/CODEOWNERS ================================================ # Inspired from https://github.com/vllm-project/vllm/blob/main/.github/CODEOWNERS /unsloth/models/loader.py @danielhanchen @mmathew23 /unsloth/models/llama.py @Datta0 @danielhanchen @mmathew23 /unsloth/models/rl.py @Datta0 @pluesclues @danielhanchen /unsloth/models/rl_replacements.py @Datta0 @pluesclues @danielhanchen /unsloth/trainer.py @danielhanchen /unsloth/models/sentence_transformer.py @Etherll @danielhanchen /unsloth/save.py @rolandtannous @danielhanchen /unsloth/tokenizer_utils.py @mmathew23 @danielhanchen /unsloth/chat_templates.py @rolandtannous @danielhanchen /unsloth/ollama_template_mappers.py @rolandtannous @danielhanchen /unsloth/kernels/moe/*.py @Datta0 /unsloth/import_fixes.py @danielhanchen /unsloth/device_type.py @danielhanchen /unsloth/_auto_install.py @danielhanchen /unsloth/dataprep/*.py @danielhanchen /unsloth/kernels/cross_entropy_loss.py @danielhanchen /unsloth/kernels/fast_lora.py @danielhanchen /unsloth/kernels/flex_attention.py @danielhanchen /unsloth/kernels/fp8.py @Datta0 /unsloth/kernels/geglu.py @danielhanchen /unsloth/kernels/layernorm.py @danielhanchen /unsloth/kernels/rms_layernorm.py @danielhanchen /unsloth/kernels/rope_embedding.py @danielhanchen /unsloth/kernels/swiglu.py @danielhanchen /unsloth/kernels/utils.py @danielhanchen @Datta0 
/unsloth/models/_utils.py @danielhanchen @mmathew23 /unsloth/models/cohere.py @danielhanchen /unsloth/models/dpo.py @danielhanchen /unsloth/models/falcon_h1.py @danielhanchen /unsloth/models/gemma.py @danielhanchen /unsloth/models/gemma2.py @danielhanchen /unsloth/models/glm4_moe.py @Datta0 /unsloth/models/granite.py @danielhanchen /unsloth/models/llama4.py @danielhanchen /unsloth/models/loader_utils.py @Datta0 @danielhanchen /unsloth/models/mapper.py @danielhanchen /unsloth/models/mistral.py @danielhanchen /unsloth/models/qwen2.py @danielhanchen /unsloth/models/qwen3.py @Datta0 /unsloth/models/qwen3_moe.py @Datta0 /unsloth/models/vision.py @mmathew23 @danielhanchen /unsloth/utils/attention_dispatch.py @mmathew23 /unsloth/utils/hf_hub.py @mmathew23 /unsloth/utils/packing.py @mmathew23 /cli/ @rolandtannous @Manan17 /studio/frontend/ @Shine1i @rolandtannous @Manan17 /studio/frontend/public/ @Shine1i /studio/backend/ @rolandtannous /studio/backend/core/data_recipe/ @rolandtannous /studio/backend/tests/ @rolandtannous @danielhanchen /tests/ @rolandtannous @danielhanchen /scripts/ @rolandtannous @danielhanchen ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: unslothai patreon: # Replace with a single Patreon username open_collective: # Replace with a single Open Collective username ko_fi: # unsloth tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry liberapay: # Replace with a single Liberapay username issuehunt: # Replace with a single IssueHunt username otechie: # Replace with a single Otechie username lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 
================================================ FILE: .github/ISSUE_TEMPLATE/bug---issue.md ================================================ --- name: Bug / Issue about: Bug / Issue title: "[Bug] Please fill in your issue title here." labels: bug assignees: '' --- 1. Did you update? `pip install --upgrade unsloth unsloth_zoo` 2. `Colab` or `Kaggle` or local / cloud 3. Number of GPUs used (check with `nvidia-smi`) 4. Which notebook? Please link! 5. Which Unsloth version, TRL version, transformers version, PyTorch version? 6. Which trainer? `SFTTrainer`, `GRPOTrainer`, etc. ```python Put minimal code to reproduce the error here ###Remove Hugging Face token### ``` 🦥 You can also ask via our Reddit page: https://reddit.com/r/unsloth/ ================================================ FILE: .github/ISSUE_TEMPLATE/feature-request.md ================================================ --- name: Feature Request about: New features, model support, ideas title: "[Feature]" labels: feature request assignees: '' --- For new models, have you tried: ```python from unsloth import FastModel model, tokenizer = FastModel.from_pretrained( "microsoft/Phi-4-multimodal-instruct", trust_remote_code = True, ) from transformers import AutoModelForSequenceClassification model, tokenizer = FastModel.from_pretrained( auto_model = AutoModelForSequenceClassification, ) ``` ================================================ FILE: .github/workflows/stale.yml ================================================ name: 'Inactive Issue Pinger' on: schedule: - cron: '30 5 * * *' # Runs at 5:30 UTC every day jobs: stale: runs-on: ubuntu-latest permissions: issues: write steps: - uses: actions/stale@v10 with: # The message to post on stale issues. # This message will ping the issue author. # Note: The stale bot action does not currently support a direct placeholder for the last commenter. # As a workaround, this message encourages any participant to reply. stale-issue-message: > Is this issue still important to you?
Apologies in advance; we may have missed this issue. For faster response times, please post on our subreddit (https://www.reddit.com/r/unsloth) or our Discord server (https://discord.com/invite/unsloth). # The number of days of inactivity before an issue is considered stale. days-before-issue-stale: 9999 # Set to -1 to never close stale issues. days-before-issue-close: -1 # A label to apply to stale issues. stale-issue-label: 'inactive' # The number of operations to perform per run to avoid rate limiting. operations-per-run: 500 enable-statistics: false ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *.class unsloth_compiled_cache/ # ML artifacts (large files) feature/ outputs/ exports/ /datasets/ studio/backend/assets/datasets/ unsloth_training_checkpoints/ *.gguf *.safetensors # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # UV # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. #uv.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. 
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control .pdm.toml .pdm-python .pdm-build/ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ .venv_overlay/ .venv_t5/ environment.yaml # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ # Ruff stuff: .ruff_cache/ .pre-commit-cache/ # PyPI configuration file and IDE/Editors .pypirc .vscode .idea/ .claude/ *.swp *.swo # oh-my-codex .omx/ # Firebase firebase-debug.log # Other resources/ tmp/ **/node_modules/ auth.db # Local working docs **/CLAUDE.md **/claude.md **/AGENT.md **/agent.md docs/canvas-lab-architecture.md log_rtx.txt log.txt setup_leo.sh server.pid *.log package-lock.json ================================================ FILE: .pre-commit-ci.yaml ================================================ ci: autofix_prs: true autofix_prs_limit: 5 autoupdate_schedule: monthly autoupdate_commit_msg: "chore: pre-commit autoupdate" skip: [] ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.15.6 hooks: - id: ruff args: - --fix - --exit-non-zero-on-fix - repo: local hooks: - id: ruff-format-with-kwargs name: Ruff format with kwarg spacing entry: scripts/run_ruff_format.py language: python types: [python] additional_dependencies: - ruff==0.6.9 ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at support@unsloth.ai. All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at [https://www.contributor-covenant.org/translations][translations]. [homepage]: https://www.contributor-covenant.org [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html [Mozilla CoC]: https://github.com/mozilla/diversity [FAQ]: https://www.contributor-covenant.org/faq [translations]: https://www.contributor-covenant.org/translations ================================================ FILE: CONTRIBUTING.md ================================================ # 🦥 Contributing to Unsloth Thank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others, or simply by spreading the word about Unsloth! 💕 - **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions. - **Fix Bugs**: Identify and resolve issues with the existing codebase. - **Submit Ideas**: Request new features or share enhancements you'd like to see. - **Develop Features**: Implement new functionality or improve existing tools by submitting pull requests (PRs).
- **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity. One of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or on social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟 ## Submitting Issues If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out: ### Reporting Bugs 1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues. 2. **Details Matter**: Is this on Google Colab, Kaggle, or another platform? Are you using one of Unsloth's official notebooks? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful. 3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution. ## Spread the Word Your support extends beyond code: - Spread the word by writing about Unsloth in blogs or on social media. - Share how Unsloth powers your projects. - Star our repository to show your appreciation. Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/blob/main/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone. Thank you so much for reading, and we hope you have lots of fun using Unsloth! 🦥 ================================================ FILE: COPYING ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. 
It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. 
Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. 
However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. 
Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. 
b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see . Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code. There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see . Files under unsloth/*, tests/*, scripts/* are Apache 2.0 licensed. Files under studio/*, unsloth_cli/* which is optional to install are AGPLv3 licensed. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [2024-] [Unsloth AI. 
Inc team, Daniel Han-Chen & Michael Han-Chen] Files under unsloth/*, tests/*, scripts/* are Apache 2.0 licensed. Files under studio/*, unsloth_cli/* which is optional to install are AGPLv3 licensed. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

Unsloth logo

Run and train AI models with a unified local interface.

FeaturesQuickstartNotebooksDocumentationDiscord

unsloth studio ui homepage Unsloth Studio (Beta) lets you run and train text, [audio](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning), [embedding](https://unsloth.ai/docs/new/embedding-finetuning), and [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) models on Windows, Linux and macOS. ## ⭐ Features Unsloth provides several key features for both inference and training: ### Inference * **Search + download + run models** including GGUF, LoRA adapters, safetensors * **Export models**: [Save or export](https://unsloth.ai/docs/new/studio/export) models to GGUF, 16-bit safetensors and other formats. * **Tool calling**: Support for [self-healing tool calling](https://unsloth.ai/docs/new/studio/chat#auto-healing-tool-calling) and web search * **[Code execution](https://unsloth.ai/docs/new/studio/chat#code-execution)**: lets LLMs test code in Claude artifacts and sandbox environments * [Auto-tune inference parameters](https://unsloth.ai/docs/new/studio/chat#auto-parameter-tuning) and customize chat templates. * Upload images, audio, PDFs, code, DOCX and more file types to chat with. ### Training * Train **500+ models** up to **2x faster** with up to **70% less VRAM**, with no accuracy loss. * Supports full fine-tuning, pretraining, 4-bit, 16-bit and FP8 training. * **Observability**: Monitor training live, track loss and GPU usage, and customize graphs. * **Data Recipes**: [Auto-create datasets](https://unsloth.ai/docs/new/studio/data-recipe) from **PDF, CSV, DOCX** etc. Edit data in a visual-node workflow. * **Reinforcement Learning**: The most efficient [RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide) library, using **80% less VRAM** for GRPO, [FP8](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) etc. * [Multi-GPU](https://unsloth.ai/docs/basics/multi-gpu-training-with-unsloth) training is supported, with major improvements coming soon.
## ⚡ Quickstart Unsloth can be used in two ways: through **[Unsloth Studio](https://unsloth.ai/docs/new/studio/)**, the web UI, or through **Unsloth Core**, the code-based version. Each has different requirements. ### Unsloth Studio (web UI) Unsloth Studio (Beta) works on **Windows, Linux, WSL** and **macOS**. * **CPU:** Supported for Chat and Data Recipes currently * **NVIDIA:** Training works on RTX 30/40/50, Blackwell, DGX Spark, Station and more * **macOS:** Currently supports chat and Data Recipes. **MLX training** is coming very soon * **AMD:** Chat works. Train with [Unsloth Core](#unsloth-core-code-based). Studio support is coming soon. * **Coming soon:** Training support for Apple MLX, AMD, and Intel. * **Multi-GPU:** Available now, with a major upgrade on the way #### MacOS, Linux, WSL Setup: ```bash curl -fsSL https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh ``` If you don't have `curl`, use `wget`. Then to launch after setup: ```bash source unsloth_studio/bin/activate unsloth studio -H 0.0.0.0 -p 8888 ``` #### Windows PowerShell Setup: ```powershell irm https://raw.githubusercontent.com/unslothai/unsloth/main/install.ps1 | iex ``` Then to launch after setup: ```powershell & .\unsloth_studio\Scripts\unsloth.exe studio -H 0.0.0.0 -p 8888 ``` #### MacOS, Linux, WSL developer installs: ```bash curl -LsSf https://astral.sh/uv/install.sh | sh uv venv unsloth_studio --python 3.13 source unsloth_studio/bin/activate uv pip install unsloth --torch-backend=auto unsloth studio setup unsloth studio -H 0.0.0.0 -p 8888 ``` #### Windows PowerShell developer installs: ```powershell winget install -e --id Python.Python.3.13 winget install --id=astral-sh.uv -e uv venv unsloth_studio --python 3.13 .\unsloth_studio\Scripts\activate uv pip install unsloth --torch-backend=auto unsloth studio setup unsloth studio -H 0.0.0.0 -p 8888 ``` #### Docker Use our [Docker image](https://hub.docker.com/r/unsloth/unsloth) ```unsloth/unsloth``` container. 
Run: ```bash docker run -d -e JUPYTER_PASSWORD="mypassword" \ -p 8888:8888 -p 8000:8000 -p 2222:22 \ -v $(pwd)/work:/workspace/work \ --gpus all \ unsloth/unsloth ``` #### Nightly Install - macOS, Linux, WSL: ```bash curl -LsSf https://astral.sh/uv/install.sh | sh git clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio cd unsloth_studio uv venv --python 3.13 source .venv/bin/activate uv pip install -e . --torch-backend=auto unsloth studio setup unsloth studio -H 0.0.0.0 -p 8888 ``` Then to launch every time: ```bash cd unsloth_studio source .venv/bin/activate unsloth studio -H 0.0.0.0 -p 8888 ``` #### Nightly Install - Windows: Run in Windows PowerShell: ```powershell winget install -e --id Python.Python.3.13 winget install --id=astral-sh.uv -e git clone --filter=blob:none https://github.com/unslothai/unsloth.git unsloth_studio cd unsloth_studio uv venv --python 3.13 .\.venv\Scripts\activate uv pip install -e . --torch-backend=auto unsloth studio setup unsloth studio -H 0.0.0.0 -p 8888 ``` Then to launch every time: ```powershell cd unsloth_studio .\.venv\Scripts\activate unsloth studio -H 0.0.0.0 -p 8888 ``` ### Unsloth Core (code-based) #### Linux, WSL ```bash curl -LsSf https://astral.sh/uv/install.sh | sh uv venv unsloth_env --python 3.13 source unsloth_env/bin/activate uv pip install unsloth --torch-backend=auto ``` #### Windows PowerShell ```powershell winget install -e --id Python.Python.3.13 winget install --id=astral-sh.uv -e uv venv unsloth_env --python 3.13 .\unsloth_env\Scripts\activate uv pip install unsloth --torch-backend=auto ``` For Windows, `pip install unsloth` works only if you have PyTorch installed. Read our [Windows Guide](https://unsloth.ai/docs/get-started/install/windows-installation). You can use the same Docker image as Unsloth Studio. #### AMD, Intel For RTX 50x, B200, 6000 GPUs: `uv pip install unsloth --torch-backend=auto`.
Read our guides for: [Blackwell](https://unsloth.ai/docs/blog/fine-tuning-llms-with-blackwell-rtx-50-series-and-unsloth) and [DGX Spark](https://unsloth.ai/docs/blog/fine-tuning-llms-with-nvidia-dgx-spark-and-unsloth).
To install Unsloth on **AMD** and **Intel** GPUs, follow our [AMD Guide](https://unsloth.ai/docs/get-started/install/amd) and [Intel Guide](https://unsloth.ai/docs/get-started/install/intel). ## ✨ Free Notebooks Train for free with our notebooks. Read our [guide](https://unsloth.ai/docs/get-started/fine-tuning-llms-guide). Add dataset, run, then deploy your trained model. | Model | Free Notebooks | Performance | Memory use | |-----------|---------|--------|----------| | **Qwen3.5 (4B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_5_(4B)_Vision.ipynb) | 1.5x faster | 60% less | | **gpt-oss (20B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-Fine-tuning.ipynb) | 2x faster | 70% less | | **gpt-oss (20B): GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) | 2x faster | 80% less | | **Qwen3: Advanced GRPO** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Qwen3_(4B)-GRPO.ipynb) | 2x faster | 50% less | | **Gemma 3 (4B) Vision** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb) | 1.7x faster | 60% less | | **embeddinggemma (300M)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/EmbeddingGemma_(300M).ipynb) | 2x faster | 20% less | | **Mistral Ministral 3 (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Ministral_3_VL_(3B)_Vision.ipynb) | 1.5x faster | 60% less | | **Llama 3.1 (8B) Alpaca** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb) | 2x faster | 70% less | | **Llama 3.2 Conversational** | [▶️ Start for 
free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb) | 2x faster | 70% less | | **Orpheus-TTS (3B)** | [▶️ Start for free](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Orpheus_(3B)-TTS.ipynb) | 1.5x faster | 50% less | - See all our notebooks for: [Kaggle](https://github.com/unslothai/notebooks?tab=readme-ov-file#-kaggle-notebooks), [GRPO](https://unsloth.ai/docs/get-started/unsloth-notebooks#grpo-reasoning-rl-notebooks), [TTS](https://unsloth.ai/docs/get-started/unsloth-notebooks#text-to-speech-tts-notebooks), [embedding](https://unsloth.ai/docs/new/embedding-finetuning) & [Vision](https://unsloth.ai/docs/get-started/unsloth-notebooks#vision-multimodal-notebooks) - See [all our models](https://unsloth.ai/docs/get-started/unsloth-model-catalog) and [all our notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks) - See detailed documentation for Unsloth [here](https://unsloth.ai/docs) ## 🦥 Unsloth News - **Introducing Unsloth Studio**: our new web UI for running and training LLMs. [Blog](https://unsloth.ai/docs/new/studio) - **Qwen3.5** - 0.8B, 2B, 4B, 9B, 27B, 35B-A3B, 122B-A10B are now supported. [Guide + notebooks](https://unsloth.ai/docs/models/qwen3.5/fine-tune) - Train **MoE LLMs 12x faster** with 35% less VRAM - DeepSeek, GLM, Qwen and gpt-oss. [Blog](https://unsloth.ai/docs/new/faster-moe) - **Embedding models**: Unsloth now supports ~1.8-3.3x faster embedding fine-tuning. [Blog](https://unsloth.ai/docs/new/embedding-finetuning) • [Notebooks](https://unsloth.ai/docs/get-started/unsloth-notebooks#embedding-models) - New **7x longer context RL** vs. all other setups, via our new batching algorithms. [Blog](https://unsloth.ai/docs/new/grpo-long-context) - New RoPE & MLP **Triton Kernels** & **Padding Free + Packing**: 3x faster training & 30% less VRAM.
[Blog](https://unsloth.ai/docs/new/3x-faster-training-packing) - **500K Context**: Training a 20B model with >500K context is now possible on an 80GB GPU. [Blog](https://unsloth.ai/docs/blog/500k-context-length-fine-tuning) - **FP8 & Vision RL**: You can now do FP8 & VLM GRPO on consumer GPUs. [FP8 Blog](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/fp8-reinforcement-learning) • [Vision RL](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide/vision-reinforcement-learning-vlm-rl) - **gpt-oss** by OpenAI: Read our [RL blog](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/gpt-oss-reinforcement-learning), [Flex Attention](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune/long-context-gpt-oss-training) blog and [Guide](https://unsloth.ai/docs/models/gpt-oss-how-to-run-and-fine-tune). ## 🔗 Links and Resources | Type | Links | | ----------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------ | |   **r/unsloth Reddit** | [Join Reddit community](https://reddit.com/r/unsloth) | | 📚 **Documentation & Wiki** | [Read Our Docs](https://unsloth.ai/docs) | |   **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai) | | 💾 **Installation** | [Pip & Docker Install](https://unsloth.ai/docs/get-started/install) | | 🔮 **Our Models** | [Unsloth Catalog](https://unsloth.ai/docs/get-started/unsloth-model-catalog) | | ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog) | ### Citation You can cite the Unsloth repo as follows: ```bibtex @software{unsloth, author = {Daniel Han and Michael Han and {Unsloth team}}, title = {Unsloth}, url = {https://github.com/unslothai/unsloth}, year = {2023} } ``` If you trained a model with 🦥Unsloth, you can use this cool sticker!   ### License Unsloth uses a dual-licensing model of Apache 2.0 and AGPL-3.0. 
The core Unsloth package remains licensed under **[Apache 2.0](https://github.com/unslothai/unsloth?tab=Apache-2.0-1-ov-file)**, while certain optional components, such as the Unsloth Studio UI are licensed under the open-source license **[AGPL-3.0](https://github.com/unslothai/unsloth?tab=AGPL-3.0-2-ov-file)**. This structure helps support ongoing Unsloth development while keeping the project open source and enabling the broader ecosystem to continue growing. ### Thank You to - The [llama.cpp library](https://github.com/ggml-org/llama.cpp) that lets users run and save models with Unsloth - The Hugging Face team and their libraries: [transformers](https://github.com/huggingface/transformers) and [TRL](https://github.com/huggingface/trl) - The Pytorch and [Torch AO](https://github.com/unslothai/unsloth/pull/3391) team for their contributions - And of course for every single person who has contributed or has used Unsloth! ================================================ FILE: build.sh ================================================
#!/usr/bin/env bash
# Build the Studio frontend, sanity-check its CSS output, then build the
# Python wheel. Pass "publish" as the first argument to also upload it.
set -euo pipefail

# 1. Build frontend (Vite outputs to dist/)
cd studio/frontend

# Clean stale dist to force a full rebuild
rm -rf dist

# Tailwind v4's oxide scanner respects .gitignore in parent directories.
# Python venvs create a .gitignore with "*" (ignore everything), which
# prevents Tailwind from scanning .tsx source files for class names.
# Temporarily hide any such .gitignore during the build, then restore it.
_HIDDEN_GITIGNORES=()
_restore_gitignores() {
    # "${arr[@]+...}" avoids an "unbound variable" error under `set -u`
    # on older bash versions when the array is empty.
    for _gi in "${_HIDDEN_GITIGNORES[@]+"${_HIDDEN_GITIGNORES[@]}"}"; do
        mv "${_gi}._twbuild" "$_gi" 2>/dev/null || true
    done
}
# Install the trap BEFORE hiding anything: if a `mv` fails part-way up the
# directory walk, `set -e` exits and the already-hidden files still get
# restored.
trap _restore_gitignores EXIT

_dir="$(pwd)"
while [ "$_dir" != "/" ]; do
    _dir="$(dirname "$_dir")"
    if [ -f "$_dir/.gitignore" ] && grep -qx '\*' "$_dir/.gitignore" 2>/dev/null; then
        mv "$_dir/.gitignore" "$_dir/.gitignore._twbuild"
        _HIDDEN_GITIGNORES+=("$_dir/.gitignore")
    fi
done

npm install
npm run build # outputs to studio/frontend/dist/

_restore_gitignores
trap - EXIT

# Validate CSS output -- catch truncated Tailwind builds before packaging.
# `wc -c` is run once per file (`\;`, not `+`): a batched `wc` appends a
# "total" line, which `sort -n | tail -1` would otherwise report as the
# "largest" file. `|| true` stops `set -e`/`pipefail` from aborting silently
# when dist/assets is missing, so the explicit error below is shown instead.
MAX_CSS_SIZE=$(find dist/assets -name '*.css' -exec wc -c {} \; 2>/dev/null | sort -n | tail -1 | awk '{print $1}' || true)
if [ -z "$MAX_CSS_SIZE" ]; then
    echo "❌ ERROR: No CSS files were emitted into dist/assets."
    echo " The frontend build may have failed silently."
    exit 1
fi
if [ "$MAX_CSS_SIZE" -lt 100000 ]; then
    echo "❌ ERROR: Largest CSS file is only $((MAX_CSS_SIZE / 1024))KB (expected >100KB)."
    echo " Tailwind may not have scanned all source files."
    echo " Check for .gitignore files blocking the Tailwind oxide scanner."
    exit 1
fi
echo "✅ Frontend CSS validated (${MAX_CSS_SIZE} bytes)"

cd ../..

# 2. Clean old artifacts
rm -rf build dist *.egg-info

# 3. Build wheel
python -m build

# 4. Optionally publish
if [ "${1:-}" = "publish" ]; then
    python -m twine upload dist/*
fi
================================================ FILE: cli.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from unsloth_cli import app if __name__ == "__main__": app() ================================================ FILE: install.ps1 ================================================
# Unsloth Studio Installer for Windows PowerShell
# Usage: irm https://raw.githubusercontent.com/unslothai/unsloth/main/install.ps1 | iex
# Local: Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass; .\install.ps1

# Installs Python (if needed) and uv via winget, creates the unsloth_studio
# venv in the current directory, installs unsloth into it, and runs
# `unsloth studio setup`. Returns early (after printing an error) on failure.
function Install-UnslothStudio {
    $ErrorActionPreference = "Stop"

    $VenvName = "unsloth_studio"
    $PythonVersion = "3.13"

    Write-Host ""
    Write-Host "========================================="
    Write-Host " Unsloth Studio Installer (Windows)"
    Write-Host "========================================="
    Write-Host ""

    # ── Helper: refresh PATH from registry (preserving current session entries) ──
    # Entries are merged in Machine, User, then session order and
    # de-duplicated (first occurrence wins, so command resolution is the same)
    # so that repeated calls do not keep growing $env:Path. This helper runs
    # no native commands, so it leaves $LASTEXITCODE untouched for callers.
    function Refresh-SessionPath {
        $machine = [System.Environment]::GetEnvironmentVariable("Path", "Machine")
        $user = [System.Environment]::GetEnvironmentVariable("Path", "User")
        # @{} keys compare case-insensitively, matching Windows path semantics.
        $seen = @{}
        $merged = foreach ($entry in ("$machine;$user;$env:Path" -split ";")) {
            if ($entry -and -not $seen.ContainsKey($entry)) {
                $seen[$entry] = $true
                $entry
            }
        }
        $env:Path = $merged -join ";"
    }

    # ── Check winget ──
    if (-not (Get-Command winget -ErrorAction SilentlyContinue)) {
        Write-Host "Error: winget is not available." -ForegroundColor Red
        Write-Host " Install it from https://aka.ms/getwinget" -ForegroundColor Yellow
        Write-Host " or install Python $PythonVersion and uv manually, then re-run." -ForegroundColor Yellow
        return
    }

    # ── Install Python if no compatible version (3.11-3.13) found ──
    $DetectedPythonVersion = ""
    if (Get-Command python -ErrorAction SilentlyContinue) {
        # Probe defensively: in Windows PowerShell 5.1, any stderr output from
        # a native command redirected with 2>&1 under "Stop" becomes a
        # terminating NativeCommandError. NOTE(review): the Microsoft Store
        # "python" alias stub is believed to print to stderr -- confirm.
        # Out-String also collapses output to one string so -match sets $Matches.
        $pyVer = ""
        try {
            $pyVer = (python --version 2>&1 | Out-String).Trim()
        } catch {
            $pyVer = ""
        }
        if ($pyVer -match "Python (3\.1[1-3])\.\d+") {
            Write-Host "==> Python already installed: $pyVer"
            $DetectedPythonVersion = $Matches[1]
        }
    }
    if (-not $DetectedPythonVersion) {
        Write-Host "==> Installing Python ${PythonVersion}..."
        winget install -e --id Python.Python.3.13 --accept-package-agreements --accept-source-agreements
        Refresh-SessionPath
        if ($LASTEXITCODE -ne 0) {
            # winget returns non-zero for "already installed" -- only fail if python is truly missing
            if (-not (Get-Command python -ErrorAction SilentlyContinue)) {
                Write-Host "[ERROR] Python installation failed (exit code $LASTEXITCODE)" -ForegroundColor Red
                return
            }
        }
        $DetectedPythonVersion = $PythonVersion
    }

    # ── Install uv if not present ──
    if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {
        Write-Host "==> Installing uv package manager..."
        winget install --id=astral-sh.uv -e --accept-package-agreements --accept-source-agreements
        Refresh-SessionPath
        # Fallback: if winget didn't put uv on PATH, try the PowerShell installer
        if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {
            Write-Host " Trying alternative uv installer..."
            powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
            Refresh-SessionPath
        }
    }
    if (-not (Get-Command uv -ErrorAction SilentlyContinue)) {
        Write-Host "Error: uv could not be installed." -ForegroundColor Red
        Write-Host " Install it from https://docs.astral.sh/uv/" -ForegroundColor Yellow
        return
    }

    # ── Create venv (skip if it already exists and has a valid interpreter) ──
    $VenvPython = Join-Path $VenvName "Scripts\python.exe"
    if (-not (Test-Path $VenvPython)) {
        if (Test-Path $VenvName) { Remove-Item -Recurse -Force $VenvName }
        Write-Host "==> Creating Python ${DetectedPythonVersion} virtual environment (${VenvName})..."
        uv venv $VenvName --python $DetectedPythonVersion
        if ($LASTEXITCODE -ne 0) {
            Write-Host "[ERROR] Failed to create virtual environment (exit code $LASTEXITCODE)" -ForegroundColor Red
            return
        }
    } else {
        Write-Host "==> Virtual environment ${VenvName} already exists, skipping creation."
    }

    # ── Install unsloth directly into the venv (no activation needed) ──
    Write-Host "==> Installing unsloth (this may take a few minutes)..."
    uv pip install --python $VenvPython unsloth --torch-backend=auto
    if ($LASTEXITCODE -ne 0) {
        Write-Host "[ERROR] Failed to install unsloth (exit code $LASTEXITCODE)" -ForegroundColor Red
        return
    }

    # ── Run studio setup ──
    # setup.ps1 will handle installing Git, CMake, Visual Studio Build Tools,
    # CUDA Toolkit, Node.js, and other dependencies automatically via winget.
    Write-Host "==> Running unsloth studio setup..."
    $UnslothExe = Join-Path $VenvName "Scripts\unsloth.exe"
    & $UnslothExe studio setup
    if ($LASTEXITCODE -ne 0) {
        Write-Host "[ERROR] unsloth studio setup failed (exit code $LASTEXITCODE)" -ForegroundColor Red
        return
    }

    Write-Host ""
    Write-Host "========================================="
    Write-Host " Unsloth Studio installed!"
    Write-Host "========================================="
    Write-Host ""
    Write-Host " To launch, run:"
    Write-Host ""
    Write-Host " .\${VenvName}\Scripts\activate"
    Write-Host " unsloth studio -H 0.0.0.0 -p 8888"
    Write-Host ""
}

Install-UnslothStudio
================================================ FILE: install.sh ================================================ #!/bin/sh # Unsloth Studio Installer # Usage (curl): curl -fsSL https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh # Usage (wget): wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/install.sh | sh set -e VENV_NAME="unsloth_studio" PYTHON_VERSION="3.13" # ── Helper: download a URL to a file (supports curl and wget) ── download() { if command -v curl >/dev/null 2>&1; then curl -LsSf "$1" -o "$2" elif command -v wget >/dev/null 2>&1; then wget -qO "$2" "$1" else echo "Error: neither curl nor wget found. Install one and re-run." 
exit 1 fi } # ── Helper: check if a single package is available on the system ── _is_pkg_installed() { case "$1" in build-essential) command -v gcc >/dev/null 2>&1 ;; libcurl4-openssl-dev) command -v dpkg >/dev/null 2>&1 && dpkg -s "$1" >/dev/null 2>&1 ;; pciutils) command -v lspci >/dev/null 2>&1 ;; *) command -v "$1" >/dev/null 2>&1 ;; esac } # ── Helper: install packages via apt, escalating to sudo only if needed ── # Usage: _smart_apt_install pkg1 pkg2 pkg3 ... _smart_apt_install() { _PKGS="$*" # Step 1: Try installing without sudo (works when already root) apt-get update -y /dev/null 2>&1 || true apt-get install -y $_PKGS /dev/null 2>&1 || true # Step 2: Check which packages are still missing _STILL_MISSING="" for _pkg in $_PKGS; do if ! _is_pkg_installed "$_pkg"; then _STILL_MISSING="$_STILL_MISSING $_pkg" fi done _STILL_MISSING=$(echo "$_STILL_MISSING" | sed 's/^ *//') if [ -z "$_STILL_MISSING" ]; then return 0 fi # Step 3: Escalate -- need elevated permissions for remaining packages if command -v sudo >/dev/null 2>&1; then echo "" echo " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" echo " WARNING: We require sudo elevated permissions to install:" echo " $_STILL_MISSING" echo " If you accept, we'll run sudo now, and it'll prompt your password." echo " !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" echo "" printf " Accept? [Y/n] " if [ -r /dev/tty ]; then read -r REPLY /dev/null; then OS="wsl" fi echo "==> Platform: $OS" # ── Check system dependencies ── # cmake and git are needed by unsloth studio setup to build the GGUF inference # engine (llama.cpp). build-essential and libcurl-dev are also needed on Linux. MISSING="" command -v cmake >/dev/null 2>&1 || MISSING="$MISSING cmake" command -v git >/dev/null 2>&1 || MISSING="$MISSING git" case "$OS" in macos) # Xcode Command Line Tools provide the C/C++ compiler if ! xcode-select -p >/dev/null 2>&1; then echo "" echo "==> Xcode Command Line Tools are required." 
echo " Installing (a system dialog will appear)..." xcode-select --install /dev/null || true echo " After the installation completes, please re-run this script." exit 1 fi ;; linux|wsl) # curl or wget is needed for downloads; check both if ! command -v curl >/dev/null 2>&1 && ! command -v wget >/dev/null 2>&1; then MISSING="$MISSING curl" fi command -v gcc >/dev/null 2>&1 || MISSING="$MISSING build-essential" # libcurl dev headers for llama.cpp HTTPS support if command -v dpkg >/dev/null 2>&1; then dpkg -s libcurl4-openssl-dev >/dev/null 2>&1 || MISSING="$MISSING libcurl4-openssl-dev" fi ;; esac MISSING=$(echo "$MISSING" | sed 's/^ *//') if [ -n "$MISSING" ]; then echo "" echo "==> Unsloth Studio needs these packages: $MISSING" echo " These are needed to build the GGUF inference engine." case "$OS" in macos) if ! command -v brew >/dev/null 2>&1; then echo "" echo " Homebrew is required to install them." echo " Install Homebrew from https://brew.sh then re-run this script." exit 1 fi brew install $MISSING /dev/null 2>&1; then _smart_apt_install $MISSING else echo " apt-get is not available. Please install with your package manager:" echo " $MISSING" echo " Then re-run Unsloth Studio setup." exit 1 fi ;; esac echo "" else echo "==> All system dependencies found." fi # ── Install uv ── if ! command -v uv >/dev/null 2>&1; then echo "==> Installing uv package manager..." _uv_tmp=$(mktemp) download "https://astral.sh/uv/install.sh" "$_uv_tmp" sh "$_uv_tmp" Creating Python ${PYTHON_VERSION} virtual environment (${VENV_NAME})..." uv venv "$VENV_NAME" --python "$PYTHON_VERSION" else echo "==> Virtual environment ${VENV_NAME} already exists, skipping creation." fi # ── Install unsloth directly into the venv (no activation needed) ── echo "==> Installing unsloth (this may take a few minutes)..." uv pip install --python "$VENV_NAME/bin/python" unsloth --torch-backend=auto # ── Run studio setup ── # Ensure the venv's Python is on PATH for setup.sh's Python discovery. 
# On macOS the system Python may be outside the 3.11-3.13 range that # setup.sh requires, but uv already installed a compatible interpreter # inside the venv. VENV_ABS_BIN="$(cd "$VENV_NAME/bin" && pwd)" if [ -n "$VENV_ABS_BIN" ]; then export PATH="$VENV_ABS_BIN:$PATH" fi echo "==> Running unsloth studio setup..." "$VENV_NAME/bin/unsloth" studio setup =3.0.0 ; ('linux' in sys_platform)", "triton-windows ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", ] huggingfacenotorch = [ "wheel>=0.42.0", "packaging", "numpy", "tqdm", "psutil", "tyro", "protobuf", "sentencepiece>=0.2.0", "datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0", "accelerate>=0.34.1", "peft>=0.18.0,!=0.11.0", "huggingface_hub>=0.34.0", "hf_transfer", "diffusers", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0", "trl>=0.18.2,!=0.19.0,<=0.24.0", "sentence-transformers", ] huggingface = [ "unsloth[huggingfacenotorch]", "unsloth_zoo>=2026.3.4", "torchvision", "unsloth[triton]", ] windows = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0 ; (sys_platform == 'win32')", "xformers>=0.0.22.post7 ; (sys_platform == 'win32')", ] base = [ "unsloth[huggingface]", ] cu118only = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu121only = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in 
sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu118onlytorch211 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu121onlytorch211 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu118onlytorch212 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu121onlytorch212 = [ "xformers @ 
https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu118onlytorch220 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu121onlytorch220 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", ] cu118onlytorch230 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; 
python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu121onlytorch230 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu118onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu121onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ 
https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu124onlytorch240 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu118onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", 
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu121onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu124onlytorch250 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' 
and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu118onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu121onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu124onlytorch251 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ 
https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu118onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", ] cu124onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", 
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu126onlytorch260 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ 
https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.29.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu118onlytorch270 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu126onlytorch270 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ 
https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp310-cp310-win_amd64.whl ; python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu128onlytorch270 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp39-cp39-win_amd64.whl ; python_version=='3.9' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp310-cp310-win_amd64.whl ; 
python_version=='3.10' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp311-cp311-win_amd64.whl ; python_version=='3.11' and (sys_platform == 'win32')", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.30-cp312-cp312-win_amd64.whl ; python_version=='3.12' and (sys_platform == 'win32')", ] cu118onlytorch271 = [ "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu126onlytorch271 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu128onlytorch271 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.31.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu118onlytorch280 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu126onlytorch280 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu128onlytorch280 = [ "xformers @ https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ 
https://download.pytorch.org/whl/cu129/xformers-0.0.32.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu130onlytorch280 = [ ] cu126onlytorch290 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu128onlytorch290 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu130onlytorch290 = [ "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post1-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu126onlytorch291 = [ "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu128onlytorch291 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu130onlytorch291 = [ "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.33.post2-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu126onlytorch2100 = [ "xformers @ 
https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu126/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu128onlytorch2100 = [ "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu128/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu130onlytorch2100 = [ "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-manylinux_2_28_x86_64.whl ; ('linux' in sys_platform)", "xformers @ https://download.pytorch.org/whl/cu130/xformers-0.0.34-cp39-abi3-win_amd64.whl ; (sys_platform == 'win32')", ] cu118 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118only]", ] cu121 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121only]", ] cu118-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu118onlytorch211]", ] cu121-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu121onlytorch211]", ] cu118-torch212 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu118onlytorch212]", ] cu121-torch212 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu121onlytorch212]", ] cu118-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch220]", ] cu121-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch220]", ] cu118-torch230 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch230]", ] cu121-torch230 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch230]", ] cu118-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch240]", ] cu121-torch240 = 
[ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch240]", ] cu124-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch240]", ] cu118-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch250]", ] cu121-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch250]", ] cu124-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch250]", ] cu118-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch251]", ] cu121-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch251]", ] cu124-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch251]", ] cu118-torch260 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch260]", ] cu124-torch260 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch260]", ] cu126-torch260 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch260]", ] cu118-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch270]", ] cu126-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch270]", ] cu128-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch270]", ] cu118-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch271]", ] cu126-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch271]", ] cu128-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch271]", ] cu118-torch280 = [ 
"unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch280]", ] cu126-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch280]", ] cu128-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch280]", ] cu130-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch280]", ] cu126-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch290]", ] cu128-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch290]", ] cu130-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch290]", ] cu126-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch291]", ] cu128-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch291]", ] cu130-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch291]", ] cu126-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch2100]", ] cu128-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch2100]", ] cu130-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch2100]", ] kaggle = [ "unsloth[huggingface]", ] kaggle-new = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", ] conda = [ "unsloth[huggingface]", ] colab-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu121onlytorch211]", ] colab-ampere-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu121onlytorch211]", "packaging", "ninja", "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-torch220 = [ "unsloth[huggingface]", 
"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch220]", ] colab-ampere-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch220]", "packaging", "ninja", "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-new = [ "unsloth_zoo>=2026.3.4", "packaging", "tyro", "transformers>=4.51.3,!=4.52.0,!=4.52.1,!=4.52.2,!=4.52.3,!=4.53.0,!=4.54.0,!=4.55.0,!=4.55.1,!=4.57.0,!=4.57.4,!=4.57.5,!=5.0.0,!=5.1.0,<=5.3.0", "datasets>=3.4.1,!=4.0.*,!=4.1.0,<4.4.0", "sentencepiece>=0.2.0", "tqdm", "psutil", "wheel>=0.42.0", "numpy", "protobuf", "huggingface_hub>=0.34.0", "hf_transfer", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[triton]", "sentence-transformers", ] colab-no-deps = [ "accelerate>=0.34.1", "trl>=0.18.2,!=0.19.0,<=0.24.0", "peft>=0.18.0", "xformers ; ('linux' in sys_platform or sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "protobuf", ] colab = [ "unsloth[cu121]", ] flashattention = [ "packaging ; ('linux' in sys_platform)", "ninja ; ('linux' in sys_platform)", "flash-attn>=2.6.3 ; ('linux' in sys_platform)", ] colab-ampere = [ "unsloth[colab-ampere-torch220]", "unsloth[flashattention]", ] cu118-ampere = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118only]", "unsloth[flashattention]", ] cu121-ampere = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121only]", "unsloth[flashattention]", ] cu118-ampere-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu118onlytorch211]", "unsloth[flashattention]", ] cu121-ampere-torch211 = [ "unsloth[huggingface]", "bitsandbytes==0.45.5", "unsloth[cu121onlytorch211]", "unsloth[flashattention]", ] cu118-ampere-torch220 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch220]", "unsloth[flashattention]", ] cu121-ampere-torch220 = [ "unsloth[huggingface]", 
"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch220]", "unsloth[flashattention]", ] cu118-ampere-torch230 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch230]", "unsloth[flashattention]", ] cu121-ampere-torch230 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch230]", "unsloth[flashattention]", ] cu118-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch240]", "unsloth[flashattention]", ] cu121-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch240]", "unsloth[flashattention]", ] cu124-ampere-torch240 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch240]", "unsloth[flashattention]", ] cu118-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch250]", "unsloth[flashattention]", ] cu121-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch250]", "unsloth[flashattention]", ] cu124-ampere-torch250 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch250]", "unsloth[flashattention]", ] cu118-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch251]", "unsloth[flashattention]", ] cu121-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu121onlytorch251]", "unsloth[flashattention]", ] cu124-ampere-torch251 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch251]", "unsloth[flashattention]", ] cu118-ampere-torch260 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch260]", "unsloth[flashattention]", ] cu124-ampere-torch260 = [ "unsloth[huggingface]", 
"bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu124onlytorch260]", "unsloth[flashattention]", ] cu126-ampere-torch260 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch260]", "unsloth[flashattention]", ] cu118-ampere-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch270]", "unsloth[flashattention]", ] cu126-ampere-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch270]", "unsloth[flashattention]", ] cu128-ampere-torch270 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch270]", "unsloth[flashattention]", ] cu118-ampere-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch271]", "unsloth[flashattention]", ] cu126-ampere-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch271]", "unsloth[flashattention]", ] cu128-ampere-torch271 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch271]", "unsloth[flashattention]", ] cu118-ampere-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu118onlytorch280]", "unsloth[flashattention]", ] cu126-ampere-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch280]", "unsloth[flashattention]", ] cu128-ampere-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch280]", "unsloth[flashattention]", ] cu130-ampere-torch280 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch280]", "unsloth[flashattention]", ] cu126-ampere-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch290]", ] cu128-ampere-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", 
"unsloth[cu128onlytorch290]", ] cu130-ampere-torch290 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch290]", ] cu126-ampere-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch291]", ] cu128-ampere-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch291]", ] cu130-ampere-torch291 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch291]", ] cu126-ampere-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu126onlytorch2100]", ] cu128-ampere-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu128onlytorch2100]", ] cu130-ampere-torch2100 = [ "unsloth[huggingface]", "bitsandbytes>=0.45.5,!=0.46.0,!=0.48.0", "unsloth[cu130onlytorch2100]", ] flashattentiontorch260abiFALSEcu12x = [ "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] flashattentiontorch260abiTRUEcu12x = [ "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] flashattentiontorch250abiFALSEcu12x = [ "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] flashattentiontorch250abiTRUEcu12x = [ "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] flashattentiontorch240abiFALSEcu12x = [ "flash-attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiFALSE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] flashattentiontorch240abiTRUEcu12x = [ "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp39-cp39-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.9'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.10'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp311-cp311-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.11'", "flash-attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp312-cp312-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.12'", "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.4cxx11abiTRUE-cp313-cp313-linux_x86_64.whl ; ('linux' in sys_platform) and python_version == '3.13'", ] intelgputorch260 = [ "unsloth_zoo[intelgpu]", "unsloth[huggingfacenotorch]", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp39-cp39-linux_x86_64.whl#sha256=147607f190a7d7aa24ba454def5977fbbfec792fdae18e4ed278cfec29b69271 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp310-cp310-linux_x86_64.whl#sha256=23aa423fa1542afc34f67eb3ba8ef20060f6d1b3a4697eaeab22b11c92b30f2b ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp311-cp311-linux_x86_64.whl#sha256=bcfa995229bbfd9ffd8d6c8d9f6428d393e876fa6e23ee3c20e3c0d73ca75ca5 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp312-cp312-linux_x86_64.whl#sha256=bd340903d03470708df3442438acb8b7e08087ab9e61fbe349b2872bf9257ab0 ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.2.0-cp313-cp313-linux_x86_64.whl#sha256=814dccc8a07159e6eca74bed70091bc8fea2d9dd87b0d91845f9f38cde62f01c ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine 
== 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=6a8adf6dc4c089406e8b3a7e58ab57a463bddf9b07130d2576e76eced43e92af ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=ff4561cbf07c83bbccaa0f6e9bb0e6dcf721bacd53c9c43c4eb0e7331b4792f9 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=12005f66b810ddd3ab93f86c4522bcfdd412cbd27fc9d189b661ff7509bc5e8a ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c4c5c67625cdacf35765c2b94e61fe166e3c3f4a14521b1212a59ad1b3eb0f2e ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.6.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=e6864f7a60a5ecc43d5d38f59a16e5dd132384f73dfd3a697f74944026038f7b ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or 
platform_machine == 'x86_64')", ] intel-gpu-torch260 = [ "unsloth[intelgputorch260]" ] intelgputorch270 = [ "unsloth_zoo[intelgpu]", "unsloth[huggingfacenotorch]", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=749a7098492c6a27b356c97149a4a62973b953eae60bc1b6259260974f344913 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=44362e80abd752471a08341093321955b066daa2cfb4810e73b8e3b240850f93 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=faa6b8c945a837a080f641bc8ccc77a98fa66980dcd7e62e715fd853737343fd ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=40f6fb65b345dc9a61813abe7ac9a585f2c9808f414d140cc2a5f11f53ee063c ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=b22b4c02ec71b4bfc862ae3cdfd2871dc0b05d2b1802f5db2196e0f897d581e9 ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ 
https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp39-cp39-win_amd64.whl#sha256=d4b738d7fa5100c1bd766f91614962828a4810eb57b4df92cd5214a83505a752 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp310-cp310-win_amd64.whl#sha256=143fe8a64d807bcdb7d81bbc062816add325570aa160448454ab6ded4a0a17a1 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp311-cp311-win_amd64.whl#sha256=a8025459ff325d6e3532eb5cf72519db1b178155e7d60aff6c56beb5968fc758 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp312-cp312-win_amd64.whl#sha256=0dd07e6d5b872e42e48f5ee140e609d4554ca3cc509d5bf509ac232267cf358e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.3.0-cp313-cp313-win_amd64.whl#sha256=a936a18182d8e065a9933afc9a3ebbffadd38604969f87c493831214539fc027 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine 
== 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-linux_x86_64.whl#sha256=f8ee75e50fcbb37ed5b498299ca2264da99ab278a93fae2358e921e4a6e28273 ; ('linux' in sys_platform) and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=d6fdc342961d98fdcd9d03dfd491a3208bb5f7fbb435841f8f72ce9fdcd2d026 ; ('linux' in sys_platform) and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=74d07f9357df5cf2bf223ad3c84de16346bfaa0504f988fdd5590d3e177e5e86 ; ('linux' in sys_platform) and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=c806d44aa2ca5d225629f6fbc6c994d5deaac2d2cde449195bc8e3522ddd219a ; ('linux' in sys_platform) and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=25d8277b7f01d42e2e014ccbab57a2692b6ec4eff8dcf894eda1b297407cf97a ; ('linux' in sys_platform) and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=046e85125266ae69c1a0d083e6c092f947ab4b6b41532c16bafe40dbced845df ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=9ebaeffb82b0b3e39b6030927d3ebe0eb62a0e9045a3b2d7b0a9e7b15222c0db ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or 
platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=356ba66cee127e7e2c942880bd50e03768306a4ea08d358a0f29c6eebfc4bc81 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=94739e665d9b4d5cd7af5f517cb6103f6f9fb421c095184609653a24524040f5 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.7.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=31df3cb674918e89bc8c532baa331dc84f4430e1f9c0ec379232db44cba78355 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", ] intel-gpu-torch270 = [ "unsloth[intelgputorch270]" ] intelgputorch280 = [ "unsloth_zoo[intelgpu]", "unsloth[huggingfacenotorch]", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=ac4d8e33986b1c3c5e48151640539272b2187e83016985853111b46fb82c3c94 ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=999fef4c1f711092b9d3086525920545df490de476ecebe899ffc777019ae17f ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=57b09c8c492985ff6a27cd3a22b08e8f7b96b407bd8030967b6efbb9f63b80cf ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or 
platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=df4bb3282bac9a3b90231700077110d8680b338416de03c2b7c6133c9b602649 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=60da63c99ca827bdcb0df28e0298bf7d066dc607454c6d6176783cb4e79d838b ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp39-cp39-win_amd64.whl#sha256=64aea8de349f3e2e0ebf4c24b011a8122531fdffda5776edaef45829cc241cf8 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp310-cp310-win_amd64.whl#sha256=ae573d255b257fdbed319a3440dc9d0a721e31160ab7f6eba1b2226e6a409a1d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp311-cp311-win_amd64.whl#sha256=8e0ea4558e5776d8ddab0264310be9b26aee5641bcac0da023537556d4317b86 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp312-cp312-win_amd64.whl#sha256=4090dde07a4fffc34aaf855701a9db28e9fccb57b368ade520f1a0f8e811c878 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ 
https://download.pytorch.org/whl/pytorch_triton_xpu-3.4.0-cp313-cp313-win_amd64.whl#sha256=a33d0888f3c8df028a2d028842715837d0049524d6c06b9bb11869890a13601a ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-linux_x86_64.whl ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=f2f401276892428e4875cf1d8717c5cbab704b16fc594ccf23795e7b16549a99 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=125c60cd59d51b39581a7e9afcd4679bc3a6b8c1f9440b1bb502a23fdd60571e ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ 
https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=47f1a57258cd460e80b38b2ed6744e31587ab77a96b4215bf59546cb4bab5cc0 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=0937d8943c145a83d9bafc6f80ef28971167817f9eda26066d33f72caf8a6646 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.8.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=e034aab1d71760dc80a731531be43673ffe15e99033b82d24e40d2e6d41bd8bf ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-manylinux_2_28_x86_64.whl#sha256=6e981c192045fc249c008441179ff237bb00174d818b875b0475730b63f0eaca ; 'linux' in sys_platform and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=e5ba4805969277175ebfd59cc717093528cc6e3ada89ac2725fc7a3c1fee6169 ; 'linux' in sys_platform and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", 
"torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=74c39c144104416bc4c5ad8c26ab0c169dc5cc6be58059e01bc3665dd0ef676f ; 'linux' in sys_platform and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=0acec355b80c3899841184084f365df336c508602812e34a44007b8b60d53af4 ; 'linux' in sys_platform and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=e2109ae773dad27b98ca17681044b4f876563c37f2382b75de3a371399edcff8 ; 'linux' in sys_platform and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp39-cp39-win_amd64.whl#sha256=5f7904e7048d414379bc8c1167260f1e84204f105db2d0a2f9c89e87ce1cf205 ; sys_platform == 'win32' and python_version == '3.9' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=005fca5e658ca8e37adb63c1a021c84f5e56dfa6cf0d601d89cfe40b9473f79f ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=c6d030f5361461550c0ff1339b5bca8585fc1e84fda2e64b6184e65a581e4f98 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=91aafd61864cdce27461cbec13ddbf28c1bc6494265a1e4b80131c64a3b7d18f ; 
sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.23.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=71dc4a6421742ed1e7f585b04a100ad53615c341fbccfbc255aefb38ea9091da ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", ] intel-gpu-torch280 = [ "unsloth[intelgputorch280]" ] intelgputorch290 = [ "unsloth_zoo[intelgpu]", "unsloth[huggingfacenotorch]", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == 
'3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=5afbe860ce991825a36b75706a523601087e414b77598ef0d9d3d565741c277d ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=607fe419c32d6e8e0556f745742e7cff1d0babce51f54be890e0c1422359c442 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=376bae584d89980b8e59934d248c38d5fa3b7d4687a4df1a19f4bc1d23dcc8c1 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=98d6a06dd7fb185874367b18bd609f05f16fdce4142a5980ca94461949965cd2 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 
'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=47cc68f631f65bd9c84924d052cd04dec7531023caa85e80345e9c94611c887d ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=d56c44ab4818aba57e5c7b628f422d014e0d507427170a771c5be85e308b0bc6 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=18cad93aaff76a01ce73aef6935ece7cfc03344b905592ec731446c44d44592b ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.9.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=579929cdc10a76800ead41289cac191ea36d1b16f5f501d3fc25607d4375cd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=cbfae2b79b7549fd368c2462fc8e94f8f26cc450782ee72138e908077c09a519 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ 
https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=044fa36ef4b6b43edcd490b75c853fa4b3eb033c2bded29f8fbcf27734713c67 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=4b91e4bec1d740a6211f02578a79888550b73f3a4e1383035f8f6d72f587212c ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=88239e73ca37254bec84f29cd5887e10ff712de7edbbda3fbb3609cd6190d99e ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=19c7da8ca767d593e13a88a12bb08d06e34a673f6f26c2f9c191d60e81c02953 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=9bb0d1421c544ac8e2eca5b47daacaf54706dc9139c003aa5e77ee5f355c5931 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=6a5194bc736089606342d48a3f6822829b167617e9495d91d753dd1bd46fda18 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.24.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=da47a3ce2bb7f0301a31124668b5908f9b9e92d6241443de15a310ef9632fd83 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or 
platform_machine == 'x86_64')", ] intel-gpu-torch290 = [ "unsloth[intelgputorch290]" ] intelgputorch210 = [ "unsloth_zoo[intelgpu]", "unsloth[huggingfacenotorch]", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=c169a1de14c19673b17c751290d467fa282fc90fa5da4314b2e5cdab1f553146 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=013d9dd5d6479bd22983161f462e61c8dbe1d82e6730624a7a8d5945507eaa61 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=afc8cabfbf7ed51fd278d1e0f88d6afc157b0201bad4b99d681e4d542f9e66d4 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl#sha256=0d24c1716088f2764d0d24c64227732195b6a42706c3c5fc89eeb4904bfa0818 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp310-cp310-win_amd64.whl#sha256=c83ab007311d9cfb6e809ee5a4587d99a9eef4be720b90da4f1aaa68b45139a0 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp311-cp311-win_amd64.whl#sha256=debf75348da8e8c7166b4d4a9b91d1508bb8d6581e339f79f7604b2e6746bacd ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", 
"pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp312-cp312-win_amd64.whl#sha256=97337a47425f1963a723475bd61037460e84ba01db4f87a1d662c3718ff6c47e ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "pytorch_triton_xpu @ https://download.pytorch.org/whl/pytorch_triton_xpu-3.5.0-cp313-cp313-win_amd64.whl#sha256=2caf8138695f6abb023ecd02031a2611ba1bf8fff2f19802567cb2fadefe9e87 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-linux_x86_64.whl#sha256=abb1d1ec1ac672bac0ff35420c965f2df0c636ef9d94e2a830e34578489d0a57 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-linux_x86_64.whl#sha256=71ad2f82da0f41eaec159f39fc85854e27c2391efa91b373e550648a6f4aaad3 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-linux_x86_64.whl#sha256=b473571d478912f92881cc13f15fa18f8463fb0fb8a068c96ed47a7d45a4da0a ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-linux_x86_64.whl#sha256=3bc64a746ff25a93de140902c60c9e819d7413f5cea1e88d80999c27a5901e9c ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=ce50691ab3fb6301d9b7bb8b3834cf5fa7152a2b5f91fd24c5efdc601a25b780 ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ 
https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=cb9d37f21cb9fb7df67d62863f021c3144e8d8832b9ea8e8523ac308bc620ea1 ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=3ad605be4728b6d3a28a44d07dd794b1a9e45551b0057815bf25eb2a6d6a56a7 ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torch @ https://download.pytorch.org/whl/xpu/torch-2.10.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=2b4b56dd6c792aef82006904fa888692e3782e4ae5da27526801bad4898f05a5 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-manylinux_2_24_x86_64.whl ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "bitsandbytes @ https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-1.33.7.preview-py3-none-win_amd64.whl ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=7e1e7b170fcf7161c8499b67156c5a05462243626dc0974010791a0bab4378d3 ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-manylinux_2_28_x86_64.whl#sha256=bd6add201bd7628af70437292e1447abb368e0b5f4ff9abd334ae435efd44792 ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ 
https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-manylinux_2_28_x86_64.whl#sha256=6ad2543496bc29e59d3dd614a94d09aa9870318aedb66045344fffddfedd2cf8 ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-manylinux_2_28_x86_64.whl#sha256=80269f37865fcd8b57f20e4786efae2200bfa2b2727926c3c7acc82f0e7d3548 ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp310-cp310-win_amd64.whl#sha256=6b9485ba85dcba4d196d6134d9c3332fb228fb2556416bf0450a64e8a472fcba ; sys_platform == 'win32' and python_version == '3.10' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp311-cp311-win_amd64.whl#sha256=36cbaedf10f6412af5c89afd9aeea474e6a56a0050348ada8fabe1ecaf6b879e ; sys_platform == 'win32' and python_version == '3.11' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp312-cp312-win_amd64.whl#sha256=738357d97468d75fe3d510ac37e65130f2787f81d9bbc1518898f7396dc3403f ; sys_platform == 'win32' and python_version == '3.12' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "torchvision @ https://download.pytorch.org/whl/xpu/torchvision-0.25.0%2Bxpu-cp313-cp313-win_amd64.whl#sha256=1c4b44b36a557f7381e3076fb8843366742238648441d607c8d049c6da0f8886 ; sys_platform == 'win32' and python_version == '3.13' and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", ] intel-gpu-torch210 = [ "unsloth[intelgputorch210]" ] intel = [ "unsloth[intelgputorch280]", ] amd = [ "unsloth[huggingfacenotorch]", "bitsandbytes>=0.49.1 ; ('linux' in sys_platform) and (platform_machine == 'AMD64' or platform_machine == 'x86_64' or 
platform_machine == 'aarch64')", "bitsandbytes>=0.49.1 ; (sys_platform == 'win32') and (platform_machine == 'AMD64' or platform_machine == 'x86_64')", ] rocm702-torch280 = [ "unsloth[amd]", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/triton-3.4.0%2Brocm7.0.2.gitf9e5bf54-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torch-2.8.0%2Brocm7.0.2.lw.git245bf6ed-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", 
"torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/torchvision-0.23.0%2Brocm7.0.2.git824e8c87-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", ] rocm72-torch291 = [ "unsloth[amd]", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.5.1%2Brocm7.2.0.gita272dfa8-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version 
== '3.13' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torch-2.9.1%2Brocmsdk20260116-cp312-cp312-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.12'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.24.0%2Brocm7.2.0.gitb919bd0c-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/windows/rocm-rel-7.2/torchvision-0.24.1%2Brocmsdk20260116-cp312-cp312-win_amd64.whl ; sys_platform == 'win32' and python_version == '3.12'", ] rocm711-torch291 = [ "unsloth[amd]", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == 
'3.12' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.5.1%2Brocm7.1.1.gita272dfa8-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.9.1%2Brocm7.1.1.lw.git351ff442-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.24.0%2Brocm7.1.1.gitb919bd0c-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", ] rocm72-torch2100 = [ "unsloth[amd]", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/triton-3.6.0%2Brocm7.2.0.gitba5c1517-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.10.0%2Brocm7.2.0.lw.gitb6ee5fde-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' 
and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torchvision-0.25.0%2Brocm7.2.0.git82df5f59-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", ] rocm711-torch2100 = [ "unsloth[amd]", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "triton @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/triton-3.6.0%2Brocm7.1.1.gitba5c1517-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torch @ 
https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torch @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torch-2.10.0%2Brocm7.1.1.lw.gitd9556b05-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.10' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.11' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.12' and platform_machine == 'x86_64'", "torchvision @ https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/torchvision-0.25.0%2Brocm7.1.1.git82df5f59-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and python_version == '3.13' and platform_machine == 'x86_64'", ] [project.urls] homepage = "https://unsloth.ai" documentation = "https://unsloth.ai/docs" repository = "https://github.com/unslothai/unsloth" [tool.ruff] target-version = 
"py311" force-exclude = true extend-exclude = [ "*chat_templates.py", "*ollama_template_mappers.py", "*_auto_install.py", "*mapper.py", ] [tool.ruff.lint] select = ["E9", "F63", "F7", "F82"] ignore = [ "E402", "E722", "F403", "F405", "F811", "F821", "F841", "F401", "E731", "E741", "F601", "E712", ] [tool.ruff.format] ================================================ FILE: scripts/enforce_kwargs_spacing.py ================================================ #!/usr/bin/env python3 """Ensure keyword arguments use spaces around '=', prune redundant pass statements.""" from __future__ import annotations import ast import argparse import io import sys import tokenize from collections import defaultdict from pathlib import Path def enforce_spacing(text: str) -> tuple[str, bool]: """Return updated text with keyword '=' padded by spaces, plus change flag.""" lines = text.splitlines(keepends=True) if not lines: return text, False offsets: dict[int, int] = defaultdict(int) changed = False reader = io.StringIO(text).readline for token in tokenize.generate_tokens(reader): if token.type != tokenize.OP or token.string != "=": continue line_index = token.start[0] - 1 col = token.start[1] + offsets[line_index] if line_index < 0 or line_index >= len(lines): continue line = lines[line_index] if col >= len(line) or line[col] != "=": continue line_changed = False # Insert a space before '=' when missing and not preceded by whitespace. if col > 0 and line[col - 1] not in {" ", "\t"}: line = f"{line[:col]} {line[col:]}" offsets[line_index] += 1 col += 1 line_changed = True changed = True # Insert a space after '=' when missing and not followed by whitespace or newline. 
next_index = col + 1 if next_index < len(line) and line[next_index] not in {" ", "\t", "\n", "\r"}: line = f"{line[:next_index]} {line[next_index:]}" offsets[line_index] += 1 line_changed = True changed = True if line_changed: lines[line_index] = line if not changed: return text, False return "".join(lines), True def remove_redundant_passes(text: str) -> tuple[str, bool]: """Drop pass statements that share a block with other executable code.""" try: tree = ast.parse(text) except SyntaxError: return text, False redundant: list[ast.Pass] = [] def visit(node: ast.AST) -> None: for attr in ("body", "orelse", "finalbody"): value = getattr(node, attr, None) if not isinstance(value, list) or len(value) <= 1: continue for stmt in value: if isinstance(stmt, ast.Pass): redundant.append(stmt) for stmt in value: if isinstance(stmt, ast.AST): visit(stmt) handlers = getattr(node, "handlers", None) if handlers: for handler in handlers: visit(handler) visit(tree) if not redundant: return text, False lines = text.splitlines(keepends=True) changed = False for node in sorted( redundant, key=lambda item: (item.lineno, item.col_offset), reverse=True ): start = node.lineno - 1 end = (node.end_lineno or node.lineno) - 1 if start >= len(lines): continue changed = True if start == end: line = lines[start] col_start = node.col_offset col_end = node.end_col_offset or (col_start + 4) segment = line[:col_start] + line[col_end:] lines[start] = segment if segment.strip() else "" continue # Defensive fall-back for unexpected multi-line 'pass'. prefix = lines[start][: node.col_offset] lines[start] = prefix if prefix.strip() else "" for idx in range(start + 1, end): lines[idx] = "" suffix = lines[end][(node.end_col_offset or 0) :] lines[end] = suffix # Normalise to ensure lines end with newlines except at EOF. 
result_lines: list[str] = [] for index, line in enumerate(lines): if not line: continue if index < len(lines) - 1 and not line.endswith("\n"): result_lines.append(f"{line}\n") else: result_lines.append(line) return "".join(result_lines), changed def process_file(path: Path) -> bool: try: with tokenize.open(path) as handle: original = handle.read() encoding = handle.encoding except (OSError, SyntaxError) as exc: # SyntaxError from tokenize on invalid python print(f"Failed to read {path}: {exc}", file=sys.stderr) return False updated, changed = enforce_spacing(original) updated, removed = remove_redundant_passes(updated) if changed or removed: path.write_text(updated, encoding=encoding) return True return False def main(argv: list[str]) -> int: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("files", nargs="+", help="Python files to fix") args = parser.parse_args(argv) touched: list[Path] = [] self_path = Path(__file__).resolve() for entry in args.files: path = Path(entry) # Skip modifying this script to avoid self-edit loops. 
if path.resolve() == self_path: continue if not path.exists() or path.is_dir(): continue if process_file(path): touched.append(path) if touched: for path in touched: print(f"Adjusted kwarg spacing in {path}") return 0 if __name__ == "__main__": sys.exit(main(sys.argv[1:])) ================================================ FILE: scripts/run_ruff_format.py ================================================ #!/usr/bin/env python3 """Run `ruff format` followed by kwarg spacing enforcement.""" from __future__ import annotations import subprocess import sys from pathlib import Path HERE = Path(__file__).resolve().parent def main(argv: list[str]) -> int: files = [arg for arg in argv if Path(arg).exists()] if not files: return 0 ruff_cmd = [sys.executable, "-m", "ruff", "format", *files] ruff_proc = subprocess.run(ruff_cmd) if ruff_proc.returncode != 0: return ruff_proc.returncode spacing_script = HERE / "enforce_kwargs_spacing.py" spacing_cmd = [sys.executable, str(spacing_script), *files] spacing_proc = subprocess.run(spacing_cmd) return spacing_proc.returncode if __name__ == "__main__": raise SystemExit(main(sys.argv[1:])) ================================================ FILE: studio/LICENSE.AGPL-3.0 ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. 
By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. 
This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. 
If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. 
For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 
No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". 
c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. 
b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. 
A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. 
But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. 
Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. 
If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. 
Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. 
The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. 
"Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. 
If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. 
Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. 
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. Copyright (C) This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code. There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see <https://www.gnu.org/licenses/>. ================================================ FILE: studio/Unsloth_Studio_Colab.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "id": "6b87de59", "metadata": {}, "source": [ "To run this, press \"*Runtime*\" and press \"*Run all*\" on a **free** Tesla T4 Google Colab instance!\n", "
\n", "\n", "\n", " Join Discord if you need help + ⭐ Star us on Github ⭐\n", "
\n", "\n", "To install Unsloth Studio on your local device, follow [our guide](https://unsloth.ai/docs/new/unsloth-studio/install). Unsloth Studio is licensed [AGPL-3.0](https://github.com/unslothai/unsloth/blob/main/studio/LICENSE.AGPL-3.0).\n", "\n", "### Unsloth Studio\n", "\n", "Train and run open models with [**Unsloth Studio**](https://unsloth.ai/docs/new/unsloth-studio/start). Currently, installation may take 30+ minutes, so we recommend using a newer GPU.\n", "\n", "\n", "We are actively working on making Unsloth Studio install faster on Colab T4 GPUs.\n", "\n", "[Features](https://unsloth.ai/docs/new/unsloth-studio#features) • [Quickstart](https://unsloth.ai/docs/new/unsloth-studio/start) • [Data Recipes](https://unsloth.ai/docs/new/unsloth-studio/data-recipe) • [Studio Chat](https://unsloth.ai/docs/new/unsloth-studio/chat) • [Export](https://unsloth.ai/docs/new/unsloth-studio/export)" ] }, { "cell_type": "markdown", "id": "e4206349", "metadata": {}, "source": [ "

" ] }, { "cell_type": "markdown", "id": "27da2957", "metadata": {}, "source": [ "### Setup: Clone repo and run setup" ] }, { "cell_type": "code", "execution_count": null, "id": "27e68f91", "metadata": {}, "outputs": [], "source": [ "!git clone --depth 1 --branch main https://github.com/unslothai/unsloth.git\n", "%cd /content/unsloth\n", "\n", "# Run setup script\n", "!chmod +x studio/setup.sh\n", "!./studio/setup.sh" ] }, { "cell_type": "markdown", "id": "3e1771a9", "metadata": {}, "source": [ "### Start Unsloth Studio" ] }, { "cell_type": "code", "execution_count": null, "id": "277e431e", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, '/content/unsloth/studio/backend')\n", "\n", "from colab import start\n", "start()" ] }, { "cell_type": "markdown", "id": "f2b0c6a1", "metadata": {}, "source": [ "And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs, want to keep up to date with the latest LLM news, need help, or want to join projects, feel free to join our Discord!\n", "\n", "Some other resources:\n", "1. Looking to use Unsloth locally? Read our [Installation Guide](https://unsloth.ai/docs/get-started/install) for details on installing Unsloth on Windows, Docker, AMD, and Intel GPUs.\n", "2. Learn how to do Reinforcement Learning with our [RL Guide and notebooks](https://unsloth.ai/docs/get-started/reinforcement-learning-rl-guide).\n", "3. Read our guides and notebooks for [Text-to-speech (TTS)](https://unsloth.ai/docs/basics/text-to-speech-tts-fine-tuning) and [vision](https://unsloth.ai/docs/basics/vision-fine-tuning) model support.\n", "4. Explore our [LLM Tutorials Directory](https://unsloth.ai/docs/models/tutorials-how-to-fine-tune-and-run-llms) to find dedicated guides for each model.\n", "5. Need help with Inference? 
Read our [Inference & Deployment page](https://unsloth.ai/docs/basics/inference-and-deployment) for details on using vLLM, llama.cpp, Ollama etc.\n", "\n", "
\n", " \n", " \n", " \n", "\n", " Join Discord if you need help + ⭐️ Star us on Github ⭐️\n", "\n", " This notebook is licensed AGPL-3.0\n", "
" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "include_colab_link": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: studio/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/assets/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/assets/configs/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/assets/configs/full_finetune.yaml ================================================ model: unsloth/Qwen2.5-0.5B data: dataset: tatsu-lab/alpaca format_type: auto training: training_type: full max_seq_length: 2048 load_in_4bit: false output_dir: outputs num_epochs: 1 learning_rate: 0.0002 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 0 save_steps: 0 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: "unsloth" lora: lora_r: 64 lora_alpha: 16 lora_dropout: 0.0 target_modules: "" vision_all_linear: false use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: unsloth-training enable_tensorboard: false tensorboard_dir: runs ================================================ FILE: studio/backend/assets/configs/inference_defaults.json ================================================ { "_comment": "Per-model-family inference parameter defaults. Sources: (1) Ollama params blobs, (2) Existing Unsloth Studio YAML configs. 
Patterns ordered longest-match-first.", "families": { "qwen3.5": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0, "presence_penalty": 1.5 }, "qwen3-coder": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen3-next": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen3-vl": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen3": { "temperature": 0.6, "top_p": 0.95, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen2.5-coder": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "qwen2.5-vl": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "qwen2.5-omni": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen2.5-math": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen2.5": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwen2-vl": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "qwen2": { "temperature": 0.7, "top_p": 0.8, "top_k": 20, "min_p": 0.0, "repetition_penalty": 1.0 }, "qwq": { "temperature": 0.6, "top_p": 0.95, "top_k": 40, "min_p": 0.0, "repetition_penalty": 1.0 }, "gemma-3n": { "temperature": 1.0, "top_p": 0.95, "top_k": 64, "min_p": 0.0, "repetition_penalty": 1.0 }, "gemma-3": { "temperature": 1.0, "top_p": 0.95, "top_k": 64, "min_p": 0.0, "repetition_penalty": 1.0 }, "medgemma": { "temperature": 1.0, "top_p": 0.95, "top_k": 64, "min_p": 0.0, "repetition_penalty": 1.0 }, "gemma-2": { "temperature": 1.0, "top_p": 0.95, "top_k": 64, "min_p": 0.0, "repetition_penalty": 1.0 }, "llama-4": { "temperature": 1.0, "top_p": 0.9, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "llama-3.3": { "temperature": 1.5, 
"top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "llama-3.2": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "llama-3.1": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "llama-3": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "phi-4": { "temperature": 0.8, "top_p": 0.95, "top_k": -1, "min_p": 0.0, "repetition_penalty": 1.0 }, "phi-3": { "temperature": 0.7, "top_p": 0.9, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "mistral-nemo": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "mistral-small": { "temperature": 0.15, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "mistral-large": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "magistral": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "ministral": { "temperature": 0.15, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "devstral": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "pixtral": { "temperature": 1.5, "top_p": 0.95, "top_k": -1, "min_p": 0.1, "repetition_penalty": 1.0 }, "deepseek-r1": { "temperature": 0.6, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "deepseek-v3": { "temperature": 0.6, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "deepseek-ocr": { "temperature": 0.0, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "glm-5": { "temperature": 1.0, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "glm-4": { "temperature": 1.0, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "nemotron": { "temperature": 1.0, "top_p": 1.0, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, 
"minimax-m2.5": { "temperature": 1.0, "top_p": 0.95, "top_k": 40, "min_p": 0.01, "repetition_penalty": 1.0 }, "minimax": { "temperature": 1.0, "top_p": 0.95, "top_k": 40, "min_p": 0.01, "repetition_penalty": 1.0 }, "gpt-oss": { "temperature": 1.0, "top_p": 1.0, "top_k": 0, "min_p": 0.01, "repetition_penalty": 1.0 }, "granite-4": { "temperature": 0.0, "top_p": 1.0, "top_k": 0, "min_p": 0.01, "repetition_penalty": 1.0 }, "kimi-k2": { "temperature": 0.6, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "kimi": { "temperature": 0.6, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "lfm2": { "temperature": 0.1, "top_p": 0.1, "top_k": 50, "min_p": 0.15, "repetition_penalty": 1.05 }, "smollm": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "olmo": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "falcon": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "ernie": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "seed": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "grok": { "temperature": 1.0, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 }, "mimo": { "temperature": 0.7, "top_p": 0.95, "top_k": -1, "min_p": 0.01, "repetition_penalty": 1.0 } }, "patterns": [ "qwen3.5", "qwen3-coder", "qwen3-next", "qwen3-vl", "qwen3", "qwen2.5-coder", "qwen2.5-vl", "qwen2.5-omni", "qwen2.5-math", "qwen2.5", "qwen2-vl", "qwen2", "qwq", "gemma-3n", "gemma-3", "medgemma", "gemma-2", "llama-4", "llama-3.3", "llama-3.2", "llama-3.1", "llama-3", "phi-4", "phi-3", "mistral-nemo", "mistral-small", "mistral-large", "magistral", "ministral", "devstral", "pixtral", "deepseek-r1", "deepseek-v3", "deepseek-ocr", "glm-5", "glm-4", "nemotron", "minimax-m2.5", "minimax", "gpt-oss", "granite-4", "kimi-k2", "kimi", 
"lfm2", "smollm", "olmo", "falcon", "ernie", "seed", "grok", "mimo" ] } ================================================ FILE: studio/backend/assets/configs/lora_text.yaml ================================================ model: unsloth/Qwen2.5-0.5B data: dataset: tatsu-lab/alpaca format_type: auto training: training_type: lora max_seq_length: 2048 load_in_4bit: true output_dir: outputs num_epochs: 1 learning_rate: 0.0002 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 0 save_steps: 0 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: "unsloth" lora: lora_r: 64 lora_alpha: 16 lora_dropout: 0.0 target_modules: "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" vision_all_linear: false use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: unsloth-training enable_tensorboard: false tensorboard_dir: runs ================================================ FILE: studio/backend/assets/configs/model_defaults/default.yaml ================================================ # Default model training parameters # Used for models without specific configurations training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 5e-5 batch_size: 2 gradient_accumulation_steps: 4 warmup_ratio: 0.1 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: 
"llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.7 top_p: 0.95 top_k: -1 min_p: 0.01 ================================================ FILE: studio/backend/assets/configs/model_defaults/embedding/unsloth_Qwen3-Embedding-0.6B.yaml ================================================ # Model defaults for unsloth/Qwen3-Embedding-0.6B # Based on Qwen3_Embedding_(0_6B).py embedding notebook # Also applies to: unsloth/Qwen3-Embedding-4B training: max_seq_length: 512 # num_epochs: 2 num_epochs: 0 learning_rate: 3e-5 batch_size: 256 gradient_accumulation_steps: 1 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: false optim: "adamw_8bit" lr_scheduler_type: "constant_with_warmup" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "embedding-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 50 ================================================ FILE: studio/backend/assets/configs/model_defaults/embedding/unsloth_all-MiniLM-L6-v2.yaml ================================================ # Model defaults for unsloth/all-MiniLM-L6-v2 # Based on All_MiniLM_L6_v2.py embedding notebook training: max_seq_length: 512 # num_epochs: 2 num_epochs: 0 learning_rate: 2e-4 batch_size: 256 gradient_accumulation_steps: 1 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: false optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 64 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "value" - "key" - "dense" - "query" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: 
"embedding-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 50 ================================================ FILE: studio/backend/assets/configs/model_defaults/embedding/unsloth_bge-m3.yaml ================================================ # Model defaults for unsloth/bge-m3 # Based on BGE_M3.py embedding notebook training: max_seq_length: 512 # num_epochs: 2 num_epochs: 0 learning_rate: 3e-5 batch_size: 256 gradient_accumulation_steps: 1 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: false optim: "adamw_8bit" lr_scheduler_type: "constant_with_warmup" lora: lora_r: 32 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "key" - "query" - "dense" - "value" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "embedding-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 50 ================================================ FILE: studio/backend/assets/configs/model_defaults/embedding/unsloth_embeddinggemma-300m.yaml ================================================ # Model defaults for unsloth/embeddinggemma-300m # Based on EmbeddingGemma_(300M).py embedding notebook training: max_seq_length: 1024 # num_epochs: 1 num_epochs: 0 learning_rate: 2e-5 batch_size: 64 gradient_accumulation_steps: 2 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "embedding-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 5 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/embedding/unsloth_gte-modernbert-base.yaml ================================================ # Model defaults for unsloth/gte-modernbert-base # Based on ModernBert.py embedding notebook training: max_seq_length: 512 # num_epochs: 2 num_epochs: 0 learning_rate: 3e-5 batch_size: 256 gradient_accumulation_steps: 1 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "constant_with_warmup" lora: lora_r: 64 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "Wi" - "Wo" - "Wqkv" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "embedding-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 50 ================================================ FILE: studio/backend/assets/configs/model_defaults/ernie/unsloth_ERNIE-4.5-21B-A3B-PT.yaml ================================================ # Model defaults for unsloth/ERNIE-4.5-21B-A3B-PT # Based on ERNIE_4_5_21B_A3B_PT-Conversational.ipynb # Also applies to: unsloth/ERNIE-4.5-21B-A3B-PT training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: 
studio/backend/assets/configs/model_defaults/ernie/unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml ================================================ # Model defaults for unsloth/ERNIE-4.5-VL-28B-A3B-PT # Based on ERNIE_4_5_VL_28B_A3B_PT_Vision.ipynb # Also applies to: unsloth/ERNIE-4.5-VL-28B-A3B-PT # added inference parameters from unsloth notebook training: trust_remote_code: true max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: true temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/falcon/tiiuae_Falcon-H1-0.5B-Instruct.yaml ================================================ # Model defaults for tiiuae/Falcon-H1-0.5B-Instruct # Based on Falcon_H1_(0.5B)-Alpaca.ipynb # Also applies to: tiiuae/Falcon-H1-0.5B-Instruct, unsloth/Falcon-H1-0.5B-Instruct training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 8 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: false optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.1 target_modules: - "q_proj" - 
"k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_codegemma-7b-bnb-4bit.yaml ================================================ # Model defaults for unsloth/codegemma-7b-bnb-4bit # Based on CodeGemma_(7B)-Conversational.ipynb # Also applies to: unsloth/codegemma-7b, google/codegemma-7b # added inference parameters from Ollama training: trust_remote_code: false max_seq_length: 4096 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0 top_p: 0.9 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_functiongemma-270m-it.yaml ================================================ # Model defaults for unsloth/functiongemma-270m-it # Based on FunctionGemma_(270M).ipynb # Also applies to: unsloth/functiongemma-270m-it-unsloth-bnb-4bit, google/functiongemma-270m-it # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 4096 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4
gradient_accumulation_steps: 2 warmup_steps: 10 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 128 lora_alpha: 256 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-2-27b-bnb-4bit.yaml ================================================ # Model defaults for unsloth/gemma-2-27b-bnb-4bit # Based on Gemma2_(9B)-Alpaca.ipynb (same defaults for larger models) training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-2-2b.yaml ================================================ # Model defaults for unsloth/gemma-2-2b # Based on Gemma2_(2B)-Alpaca.ipynb # Also applies to: unsloth/gemma-2-2b-bnb-4bit, google/gemma-2-2b training: 
trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-270m-it.yaml ================================================ # Model defaults for unsloth/gemma-3-270m-it # Based on Gemma3_(270M).ipynb # Also applies to: unsloth/gemma-3-270m-it-unsloth-bnb-4bit, google/gemma-3-270m-it, unsloth/gemma-3-270m-it-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 5e-5 batch_size: 4 gradient_accumulation_steps: 1 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 128 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-27b-it.yaml ================================================ # Model defaults for unsloth/gemma-3-27b-it # Based on Gemma3_(27B)_A100-Conversational.ipynb # Also applies to: unsloth/gemma-3-27b-it-unsloth-bnb-4bit, google/gemma-3-27b-it, unsloth/gemma-3-27b-it-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 8 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-4b-it.yaml ================================================ # Model defaults for unsloth/gemma-3-4b-it # Based on Gemma3_(4B).ipynb # Also applies to: unsloth/gemma-3-4b-it-unsloth-bnb-4bit, google/gemma-3-4b-it, unsloth/gemma-3-4b-it-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 8 
lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3-4b-pt.yaml ================================================ # Model defaults for unsloth/gemma-3-4b-pt # Based on Gemma3_(4B)-Vision.ipynb # Also applies to: unsloth/gemma-3-4b-pt-unsloth-bnb-4bit, google/gemma-3-4b-pt, unsloth/gemma-3-4b-pt-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 2 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: true optim: "adamw_torch_fused" lr_scheduler_type: "cosine" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3n-E4B-it.yaml ================================================ # Model defaults for unsloth/gemma-3n-E4B-it # Based on Gemma3N_(4B)-Conversational.ipynb # Also applies to: unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit, 
google/gemma-3n-E4B-it # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 1024 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 8 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 audio_input: true inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gemma/unsloth_gemma-3n-E4B.yaml ================================================ # Model defaults for unsloth/gemma-3n-E4B # Based on Gemma3N_(4B)-Vision.ipynb # Also applies to: unsloth/gemma-3n-E4B-unsloth-bnb-4bit, google/gemma-3n-E4B # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 2 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_ratio: 0.03 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: true optim: "adamw_torch_fused" lr_scheduler_type: "cosine" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning"
enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 audio_input: true inference: trust_remote_code: false temperature: 1.0 top_k: 64 top_p: 0.95 min_p: 0.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gpt-oss/unsloth_gpt-oss-120b.yaml ================================================ # Model defaults for unsloth/gpt-oss-120b # Based on gpt-oss-(120B)_A100-Fine-tuning.ipynb # Also applies to: openai/gpt-oss-120b, unsloth/gpt-oss-120b-unsloth-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 4096 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4 gradient_accumulation_steps: 1 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_p: 1.0 top_k: 0 ================================================ FILE: studio/backend/assets/configs/model_defaults/gpt-oss/unsloth_gpt-oss-20b.yaml ================================================ # Model defaults for unsloth/gpt-oss-20b # Based on gpt-oss-(20B)-Fine-tuning.ipynb # Also applies to: openai/gpt-oss-20b, unsloth/gpt-oss-20b-unsloth-bnb-4bit, unsloth/gpt-oss-20b-BF16 # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 1024 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false 
train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.0 top_p: 1.0 top_k: 0 ================================================ FILE: studio/backend/assets/configs/model_defaults/granite/unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml ================================================ # Model defaults for unsloth/granite-4.0-350m # Based on Granite4.0_350M.ipynb # Also applies to: ibm-granite/granite-4.0-350m, unsloth/granite-4.0-350m-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "shared_mlp.input_linear" - "shared_mlp.output_linear" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.0 top_p: 1.0 top_k: 0 ================================================ FILE: studio/backend/assets/configs/model_defaults/granite/unsloth_granite-4.0-h-micro.yaml ================================================ # Model defaults for unsloth/granite-4.0-h-micro # Based on Granite4.0.ipynb # Also applies to: 
ibm-granite/granite-4.0-h-micro, unsloth/granite-4.0-h-micro-bnb-4bit, unsloth/granite-4.0-h-micro-unsloth-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "shared_mlp.input_linear" - "shared_mlp.output_linear" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.0 top_p: 1.0 top_k: 0 ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-11B-Vision-Instruct.yaml ================================================ # Model defaults for unsloth/Llama-3.2-11B-Vision-Instruct # Based on Llama3.2_(11B)-Vision.ipynb # Also applies to: unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-11B-Vision-Instruct, unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true 
finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-1B-Instruct.yaml ================================================ # Model defaults for unsloth/Llama-3.2-1B-Instruct # Based on Llama3.2_(1B)-RAFT.ipynb # Also applies to: unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-1B-Instruct, unsloth/Llama-3.2-1B-Instruct-bnb-4bit, RedHatAI/Llama-3.2-1B-Instruct-FP8, unsloth/Llama-3.2-1B-Instruct-FP8-Block, unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 5 num_epochs: 0 learning_rate: 2e-5 batch_size: 1 gradient_accumulation_steps: 8 warmup_steps: 0 max_steps: 30 save_steps: 30 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: true optim: "adamw_torch" lr_scheduler_type: "cosine" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.2-3B-Instruct.yaml ================================================ # Model defaults for unsloth/Llama-3.2-3B-Instruct # Based on Llama3.2_(1B_and_3B)-Conversational.ipynb # Also applies to: unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.2-3B-Instruct, unsloth/Llama-3.2-3B-Instruct-bnb-4bit, RedHatAI/Llama-3.2-3B-Instruct-FP8, 
unsloth/Llama-3.2-3B-Instruct-FP8-Block, unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Llama-3.3-70B-Instruct.yaml ================================================ # Model defaults for unsloth/Llama-3.3-70B-Instruct # Based on Llama3.3_(70B)_A100-Conversational.ipynb # Also applies to: unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit, meta-llama/Llama-3.3-70B-Instruct, unsloth/Llama-3.3-70B-Instruct-bnb-4bit, RedHatAI/Llama-3.3-70B-Instruct-FP8, unsloth/Llama-3.3-70B-Instruct-FP8-Block, unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" 
use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Meta-Llama-3.1-70B-bnb-4bit # Based on Llama3.1_(8B)-Alpaca.ipynb # Also applies to: unsloth/Meta-Llama-3.1-8B-bnb-4bit, unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit, meta-llama/Meta-Llama-3.1-8B, unsloth/Meta-Llama-3.1-8B, unsloth/Meta-Llama-3.1-70B, meta-llama/Meta-Llama-3.1-70B, unsloth/Meta-Llama-3.1-405B-bnb-4bit, meta-llama/Meta-Llama-3.1-405B training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit # Based on Llama3.1_(8B)-Inference.ipynb # Also applies to: "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit", "meta-llama/Meta-Llama-3.1-8B-Instruct", 
"unsloth/Meta-Llama-3.1-8B-Instruct","RedHatAI/Llama-3.1-8B-Instruct-FP8","unsloth/Llama-3.1-8B-Instruct-FP8-Block","unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic" training: trust_remote_code: false max_seq_length: 8192 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_llama-3-8b-Instruct-bnb-4bit.yaml ================================================ # Model defaults for unsloth/llama-3-8b-Instruct-bnb-4bit # Based on Llama3_(8B)-Conversational.ipynb # Also applies to: unsloth/llama-3-8b-Instruct, meta-llama/Meta-Llama-3-8B-Instruct training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false 
================================================ FILE: studio/backend/assets/configs/model_defaults/llama/unsloth_llama-3-8b-bnb-4bit.yaml ================================================ # Model defaults for unsloth/llama-3-8b-bnb-4bit # Based on Llama3_(8B)-Alpaca.ipynb # Also applies to: unsloth/llama-3-8b, meta-llama/Meta-Llama-3-8B training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/llasa/unsloth_Llasa-3B.yaml ================================================ # Model defaults for unsloth/Llasa-3B # Based on Llasa_TTS_(3B).ipynb and Llasa_TTS_(1B).ipynb # Also applies to: HKUSTAudio/Llasa-1B # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 5e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 128 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "q_proj" - "v_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: 
"runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.2 top_p: 1.2 ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Magistral-Small-2509 # Based on Magistral_(24B)-Reasoning-Conversational.ipynb # Also applies to: mistralai/Magistral-Small-2509, unsloth/Magistral-Small-2509-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.7 min_p: 0.01 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_Ministral-3-3B-Instruct-2512.yaml ================================================ # Model defaults for unsloth/Ministral-3-3B-Instruct-2512 # Based on Ministral_3_VL_(3B)_Vision.ipynb # Also applies to: unsloth/Ministral-3-3B-Instruct-2512 # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 
save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.15 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Mistral-Nemo-Base-2407-bnb-4bit # Based on Mistral_Nemo_(12B)-Alpaca.ipynb # Also applies to: "unsloth/Mistral-Nemo-Base-2407", "mistralai/Mistral-Nemo-Base-2407", "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit", "unsloth/Mistral-Nemo-Instruct-2407", "mistralai/Mistral-Nemo-Instruct-2407", training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: 
studio/backend/assets/configs/model_defaults/mistral/unsloth_Mistral-Small-Instruct-2409.yaml ================================================ # Model defaults for unsloth/Mistral-Small-Instruct-2409 # Based on Mistral_Small_(22B)-Alpaca.ipynb # Also applies to: unsloth/Mistral-Small-Instruct-2409-bnb-4bit, mistralai/Mistral-Small-Instruct-2409 training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_Pixtral-12B-2409.yaml ================================================ # Model defaults for unsloth/Pixtral-12B-2409 # Based on Pixtral_(12B)-Vision.ipynb # Also applies to: unsloth/Pixtral-12B-2409-unsloth-bnb-4bit, mistralai/Pixtral-12B-2409, unsloth/Pixtral-12B-2409-bnb-4bit # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "paged_adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 8 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false 
finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: false finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml ================================================ # Model defaults for unsloth/mistral-7b-instruct-v0.3-bnb-4bit # Based on Mistral_v0.3_(7B)-Conversational.ipynb # Also applies to: unsloth/mistral-7b-instruct-v0.3, mistralai/Mistral-7B-Instruct-v0.3 training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/mistral/unsloth_mistral-7b-v0.3-bnb-4bit.yaml ================================================ # Model defaults for unsloth/mistral-7b-v0.3-bnb-4bit # Based on Mistral_v0.3_(7B)-Alpaca.ipynb # Also applies to: "unsloth/mistral-7b-v0.3", "mistralai/Mistral-7B-v0.3", training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 
weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/other/OuteAI_Llama-OuteTTS-1.0-1B.yaml ================================================ # Model defaults for OuteAI/Llama-OuteTTS-1.0-1B # Based on Oute_TTS_(1B).ipynb # Also applies to: OuteAI/Llama-OuteTTS-1.0-1B # added inference parameters from unsloth notebook audio_type: dac training: trust_remote_code: false eval_steps: 0 max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 128 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "q_proj" - "v_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.4 top_k: 40 top_p: 0.9 min_p: 0.05 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/Spark-TTS-0.5B_LLM.yaml ================================================ # Model defaults for Spark-TTS-0.5B/LLM # Based on Spark_TTS_(0_5B).ipynb # Also applies to: Spark-TTS-0.5B/LLM # added inference parameters from unsloth notebook audio_type: bicodec training: trust_remote_code: 
false eval_steps: 0 max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 128 lora_alpha: 128 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.8 top_k: 50 top_p: 1.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/sesame_csm-1b.yaml ================================================ # Model defaults for sesame/csm-1b # Based on Sesame_CSM_(1B)-TTS.ipynb # Also applies to: sesame/csm-1b audio_type: csm training: trust_remote_code: false eval_steps: 0 max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_GLM-4.7-Flash.yaml ================================================ # Model defaults for unsloth/GLM-4.7-Flash # Based on 
GLM_Flash_A100(80GB).py # Also applies to: unsloth/GLM-4.7-Flash-unsloth-bnb-4bit, unsloth/GLM-4.7-Flash-bnb-4bit, THUDM/GLM-4.7-Flash training: trust_remote_code: true max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 60 save_steps: 60 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "out_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: true temperature: 0.7 top_p: 0.8 top_k: 20 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_LFM2-1.2B.yaml ================================================ # Model defaults for unsloth/LFM2-1.2B # Based on Liquid_LFM2_(1.2B)-Conversational.ipynb # Also applies to: unsloth/LFM2-1.2B # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.3 min_p: 0.15 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/other/unsloth_Nemotron-3-Nano-30B-A3B.yaml ================================================ # Model defaults for unsloth/Nemotron-3-Nano-30B-A3B # Based on Nemotron-3-Nano-30B-A3B_A100.ipynb # Also applies to: unsloth/Nemotron-3-Nano-30B-A3B # added inference parameters from unsloth guides training: trust_remote_code: true max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 4 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 8 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "in_proj" - "out_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: true temperature: 1.0 top_p: 1.0 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_PaddleOCR-VL.yaml ================================================ # Model defaults for unsloth/PaddleOCR-VL # Based on Paddle_OCR_(1B)_Vision.ipynb # Also applies to: unsloth/PaddleOCR-VL # added inference parameters from unsloth notebook training: trust_remote_code: true max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 5e-5 batch_size: 4 gradient_accumulation_steps: 2 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 64 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false 
finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: true temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_answerdotai_ModernBERT-large.yaml ================================================ # Model defaults for answerdotai/ModernBERT-large # Based on bert_classification.ipynb training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 1 num_epochs: 0 learning_rate: 5e-5 batch_size: 32 gradient_accumulation_steps: 1 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_orpheus-3b-0.1-ft.yaml ================================================ # Model defaults for unsloth/orpheus-3b-0.1-ft # Based on Orpheus_(3B)-TTS.ipynb # Also applies to: unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit, canopylabs/orpheus-3b-0.1-ft, unsloth/orpheus-3b-0.1-ft-bnb-4bit # added inference parameters from unsloth notebook audio_type: snac training: trust_remote_code: false eval_steps: 0 max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 
random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 64 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_tinyllama-bnb-4bit.yaml ================================================ # Model defaults for unsloth/tinyllama # Based on TinyLlama_(1.1B)-Alpaca.ipynb # Also applies to: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T training: trust_remote_code: false max_seq_length: 4096 # num_epochs: 1 num_epochs: 0 learning_rate: 2e-5 batch_size: 2 gradient_accumulation_steps: 4 warmup_ratio: 0.1 max_steps: 30 save_steps: 30 weight_decay: 0.1 random_seed: 3407 packing: true train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/other/unsloth_whisper-large-v3.yaml ================================================ # Model defaults for unsloth/whisper-large-v3 # Based on Whisper.ipynb # Also applies to: unsloth/whisper-large-v3, openai/whisper-large-v3 audio_type: whisper audio_input: true training: trust_remote_code: false eval_steps: 5 max_seq_length: 448 # 
num_epochs: 4 num_epochs: 0 learning_rate: 1e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 64 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "v_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-3-medium-4k-instruct.yaml ================================================ # Model defaults for unsloth/Phi-3-medium-4k-instruct # Based on Phi_3_Medium-Conversational.ipynb # Also applies to: "unsloth/Phi-3-medium-4k-instruct-bnb-4bit", "microsoft/Phi-3-medium-4k-instruct", training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-3.5-mini-instruct.yaml ================================================ # Model defaults for unsloth/Phi-3.5-mini-instruct # Based on Phi_3.5_Mini-Conversational.ipynb # Also applies to: 
"unsloth/Phi-3.5-mini-instruct-bnb-4bit", "microsoft/Phi-3.5-mini-instruct" training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/phi/unsloth_Phi-4.yaml ================================================ # Model defaults for unsloth/Phi-4 # Based on Phi_4-Conversational.ipynb # Also applies to: unsloth/phi-4-unsloth-bnb-4bit, microsoft/phi-4, unsloth/phi-4-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.8 top_p: 0.95 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/qwen/imdatta0_tiny_qwen3_moe_2.8B_0.7B.yaml ================================================ # Model defaults for imdatta0/tiny_qwen3_moe_2.8B_0.7B # Based on TinyQwen3_MoE.py # Dummy model of qwen3moe architecture created to fit in T4 # MoE model - includes gate_up_proj for MoE layers training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 1 warmup_steps: 5 max_steps: 50 save_steps: 50 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "gate_up_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2-7B.yaml ================================================ # Model defaults for unsloth/Qwen2-7B # Based on Qwen2_(7B)-Alpaca.ipynb # Also applies to: unsloth/Qwen2-7B-bnb-4bit, Qwen/Qwen2-7B training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: 
"llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2-VL-7B-Instruct.yaml ================================================ # Model defaults for unsloth/Qwen2-VL-7B-Instruct # Based on Qwen2_VL_(7B)-Vision.ipynb # Also applies to: unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit, Qwen/Qwen2-VL-7B-Instruct, unsloth/Qwen2-VL-7B-Instruct-bnb-4bit # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-1.5B-Instruct.yaml ================================================ # Model defaults for unsloth/Qwen2.5-1.5B-Instruct # Based on nemo_gym_sudoku.ipynb # Also applies to: unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit, Qwen/Qwen2.5-1.5B-Instruct, unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit training: trust_remote_code: false max_seq_length: 4096 # num_epochs: 4 num_epochs: 0 learning_rate: 1e-5 batch_size: 1 gradient_accumulation_steps: 64 warmup_ratio: 0.1 max_steps: 30 save_steps: 30 weight_decay: 0.001 
random_seed: 42 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 4 lora_alpha: 8 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-7B.yaml ================================================ # Model defaults for unsloth/Qwen2.5-7B # Based on Qwen2.5_(7B)-Alpaca.ipynb # Also applies to: unsloth/Qwen2.5-7B-unsloth-bnb-4bit, Qwen/Qwen2.5-7B, unsloth/Qwen2.5-7B-bnb-4bit training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml ================================================ # Model defaults for unsloth/Qwen2.5-Coder-1.5B-Instruct # Based on Qwen2.5_Coder_(1.5B)-Tool_Calling.ipynb # Also applies to: unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit, Qwen/Qwen2.5-Coder-1.5B-Instruct training: trust_remote_code: false 
max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-14B-Instruct.yaml ================================================ # Model defaults for unsloth/Qwen2.5-Coder-14B-Instruct # Based on Qwen2.5_Coder_(14B)-Conversational.ipynb # Also applies to: unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit, Qwen/Qwen2.5-Coder-14B-Instruct # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "paged_adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit # Based on CodeForces-cot-Finetune_for_Reasoning_on_CodeForces.ipynb # Also applies to: unsloth/Qwen2.5-Coder-7B-Instruct, Qwen/Qwen2.5-Coder-7B-Instruct training: trust_remote_code: false max_seq_length: 32768 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit # Based on Qwen2.5_VL_(7B)-Vision.ipynb # Also applies to: unsloth/Qwen2.5-VL-7B-Instruct, Qwen/Qwen2.5-VL-7B-Instruct, unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit # added inference parameters from unsloth notebook training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" 
use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 1.5 min_p: 0.1 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-0.6B.yaml ================================================ # Model defaults for unsloth/Qwen3-0.6B # Based on Qwen3_(0_6B)-Phone_Deployment.ipynb # Also applies to: unsloth/Qwen3-0.6B-unsloth-bnb-4bit, Qwen/Qwen3-0.6B, unsloth/Qwen3-0.6B-bnb-4bit, Qwen/Qwen3-0.6B-FP8, unsloth/Qwen3-0.6B-FP8 # added inference parameters from Ollama training: trust_remote_code: false max_seq_length: 1024 # num_epochs: 4 num_epochs: 0 learning_rate: 5e-5 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Qwen3-14B-Base # Based on Qwen3_(14B)-Alpaca.ipynb # Also applies to: unsloth/Qwen3-14B-Base, Qwen/Qwen3-14B-Base, unsloth/Qwen3-14B-Base-bnb-4bit # added inference parameters from Ollama training: trust_remote_code: false 
max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-14B.yaml ================================================ # Model defaults for unsloth/Qwen3-14B # Based on Qwen3_(14B).ipynb # Also applies to: unsloth/Qwen3-14B-unsloth-bnb-4bit, Qwen/Qwen3-14B, unsloth/Qwen3-14B-bnb-4bit, Qwen/Qwen3-14B-FP8, unsloth/Qwen3-14B-FP8 # added inference parameters from Ollama training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: 
studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-30B-A3B-Instruct-2507.yaml ================================================ # Model defaults for unsloth/Qwen3-30B-A3B-Instruct-2507 # Based on Qwen3_MoE.py # Also applies to: Qwen/Qwen3-30B-A3B-Instruct-2507, unsloth/Qwen3-30B-A3B-Instruct-2507-bnb-4bit # MoE model - includes gate_up_proj for MoE layers training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 1 gradient_accumulation_steps: 1 warmup_steps: 5 max_steps: 50 save_steps: 50 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 64 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" - "gate_up_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-32B.yaml ================================================ # Model defaults for unsloth/Qwen3-32B # Based on Qwen3_(32B)_A100-Reasoning-Conversational.ipynb # Also applies to: unsloth/Qwen3-32B-unsloth-bnb-4bit, Qwen/Qwen3-32B, unsloth/Qwen3-32B-bnb-4bit, Qwen/Qwen3-32B-FP8, unsloth/Qwen3-32B-FP8 # added inference parameters from Ollama training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 
target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_k: 20 top_p: 0.95 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-4B-Instruct-2507.yaml ================================================ # Model defaults for unsloth/Qwen3-4B-Instruct-2507 # Based on Qwen3_(4B)-Instruct.ipynb # Also applies to: unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit, Qwen/Qwen3-4B-Instruct-2507, unsloth/Qwen3-4B-Instruct-2507-bnb-4bit, Qwen/Qwen3-4B-Instruct-2507-FP8, unsloth/Qwen3-4B-Instruct-2507-FP8 # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.7 top_p: 0.80 top_k: 20 min_p: 0.00 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-4B-Thinking-2507.yaml ================================================ # Model defaults for unsloth/Qwen3-4B-Thinking-2507 # Based on Qwen3_(4B)-Thinking.ipynb # Also applies to: unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit, 
Qwen/Qwen3-4B-Thinking-2507, unsloth/Qwen3-4B-Thinking-2507-bnb-4bit, Qwen/Qwen3-4B-Thinking-2507-FP8, unsloth/Qwen3-4B-Thinking-2507-FP8 # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 32 lora_alpha: 32 lora_dropout: 0.0 target_modules: - "q_proj" - "k_proj" - "v_proj" - "o_proj" - "gate_proj" - "up_proj" - "down_proj" use_rslora: false use_loftq: false logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.6 top_p: 0.95 top_k: 20 min_p: 0.00 ================================================ FILE: studio/backend/assets/configs/model_defaults/qwen/unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml ================================================ # Model defaults for unsloth/Qwen3-VL-8B-Instruct # Based on Qwen3_VL_(8B)-Vision.ipynb # Also applies to: Qwen/Qwen3-VL-8B-Instruct-FP8, unsloth/Qwen3-VL-8B-Instruct-FP8, unsloth/Qwen3-VL-8B-Instruct, Qwen/Qwen3-VL-8B-Instruct, unsloth/Qwen3-VL-8B-Instruct-bnb-4bit # added inference parameters from unsloth guides training: trust_remote_code: false max_seq_length: 2048 # num_epochs: 4 num_epochs: 0 learning_rate: 2e-4 batch_size: 2 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 30 save_steps: 30 weight_decay: 0.001 random_seed: 3407 packing: false train_on_completions: true gradient_checkpointing: "unsloth" optim: "adamw_8bit" lr_scheduler_type: "linear" lora: lora_r: 16 lora_alpha: 16 lora_dropout: 0.0 target_modules: - "all-linear" use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: 
true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: "llm-finetuning" enable_tensorboard: false tensorboard_dir: "runs" log_frequency: 10 inference: trust_remote_code: false temperature: 0.7 top_p: 0.8 top_k: 20 ================================================ FILE: studio/backend/assets/configs/vision_lora.yaml ================================================ model: unsloth/Qwen2-VL-2B-Instruct-bnb-4bit data: dataset: philschmid/amazon-product-descriptions-vlm format_type: auto training: training_type: lora max_seq_length: 2048 load_in_4bit: true output_dir: outputs num_epochs: 1 learning_rate: 0.0002 batch_size: 1 gradient_accumulation_steps: 4 warmup_steps: 5 max_steps: 0 save_steps: 0 weight_decay: 0.01 random_seed: 3407 packing: false train_on_completions: false gradient_checkpointing: "unsloth" lora: lora_r: 64 lora_alpha: 16 lora_dropout: 0.0 target_modules: "" # vision uses vision_all_linear by default vision_all_linear: true use_rslora: false use_loftq: false finetune_vision_layers: true finetune_language_layers: true finetune_attention_modules: true finetune_mlp_modules: true logging: enable_wandb: false wandb_project: unsloth-training enable_tensorboard: false tensorboard_dir: runs ================================================ FILE: studio/backend/auth/.gitkeep ================================================ ================================================ FILE: studio/backend/auth/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Authentication module for JWT-based auth with SQLite storage. 
""" from .authentication import ( create_access_token, create_refresh_token, refresh_access_token, get_current_subject, get_current_subject_allow_password_change, reload_secret, ) from .storage import ( DEFAULT_ADMIN_USERNAME, clear_bootstrap_password, generate_bootstrap_password, get_bootstrap_password, is_initialized, create_initial_user, ensure_default_admin, get_jwt_secret, get_user_and_secret, load_jwt_secret, requires_password_change, save_refresh_token, update_password, verify_refresh_token, revoke_user_refresh_tokens, ) from .hashing import hash_password, verify_password __all__ = [ "create_access_token", "create_refresh_token", "refresh_access_token", "get_current_subject", "get_current_subject_allow_password_change", "reload_secret", "DEFAULT_ADMIN_USERNAME", "clear_bootstrap_password", "generate_bootstrap_password", "get_bootstrap_password", "is_initialized", "create_initial_user", "ensure_default_admin", "get_jwt_secret", "get_user_and_secret", "load_jwt_secret", "requires_password_change", "save_refresh_token", "update_password", "verify_refresh_token", "revoke_user_refresh_tokens", "hash_password", "verify_password", ] ================================================ FILE: studio/backend/auth/authentication.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 import secrets from datetime import datetime, timedelta, timezone from typing import Optional, Tuple from fastapi import Depends, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer import jwt from .storage import ( get_jwt_secret, get_user_and_secret, load_jwt_secret, save_refresh_token, verify_refresh_token, ) ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 REFRESH_TOKEN_EXPIRE_DAYS = 7 security = HTTPBearer() # Reads Authorization: Bearer def _get_secret_for_subject(subject: str) -> str: secret = get_jwt_secret(subject) if secret is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid or expired token", ) return secret def _decode_subject_without_verification(token: str) -> Optional[str]: try: payload = jwt.decode( token, options = {"verify_signature": False, "verify_exp": False}, ) except jwt.InvalidTokenError: return None subject = payload.get("sub") return subject if isinstance(subject, str) else None def create_access_token( subject: str, expires_delta: Optional[timedelta] = None, ) -> str: """ Create a signed JWT for the given subject (e.g. username). Tokens are valid across restarts because the signing secret is stored in SQLite. """ to_encode = {"sub": subject} expire = datetime.now(timezone.utc) + ( expires_delta or timedelta(minutes = ACCESS_TOKEN_EXPIRE_MINUTES) ) to_encode.update({"exp": expire}) return jwt.encode( to_encode, _get_secret_for_subject(subject), algorithm = ALGORITHM, ) def create_refresh_token(subject: str) -> str: """ Create a random refresh token, store its hash in SQLite, and return it. Refresh tokens are opaque (not JWTs) and expire after REFRESH_TOKEN_EXPIRE_DAYS. 
""" token = secrets.token_urlsafe(48) expires_at = datetime.now(timezone.utc) + timedelta(days = REFRESH_TOKEN_EXPIRE_DAYS) save_refresh_token(token, subject, expires_at.isoformat()) return token def refresh_access_token(refresh_token: str) -> Tuple[Optional[str], Optional[str]]: """ Validate a refresh token and issue a new access token. The refresh token itself is NOT consumed — it stays valid until expiry. Returns a new access_token or None if the refresh token is invalid/expired. """ username = verify_refresh_token(refresh_token) if username is None: return None, None return create_access_token(subject = username), username def reload_secret() -> None: """ Keep legacy API compatibility for callers expecting auth storage init. Auth now resolves the current signing secret directly from SQLite. """ load_jwt_secret() async def get_current_subject( credentials: HTTPAuthorizationCredentials = Depends(security), ) -> str: """Validate JWT and require the password-change flow to be completed.""" return await _get_current_subject( credentials, allow_password_change = False, ) async def get_current_subject_allow_password_change( credentials: HTTPAuthorizationCredentials = Depends(security), ) -> str: """Validate JWT but allow access to the password-change endpoint.""" return await _get_current_subject( credentials, allow_password_change = True, ) async def _get_current_subject( credentials: HTTPAuthorizationCredentials, *, allow_password_change: bool, ) -> str: """ FastAPI dependency to validate the JWT and return the subject. Use this as a dependency on routes that should be protected, e.g.: @router.get("/secure") async def secure_endpoint(current_subject: str = Depends(get_current_subject)): ... 
""" token = credentials.credentials subject = _decode_subject_without_verification(token) if subject is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid token payload", ) record = get_user_and_secret(subject) if record is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid or expired token", ) _salt, _pwd_hash, jwt_secret, must_change_password = record try: payload = jwt.decode(token, jwt_secret, algorithms = [ALGORITHM]) if payload.get("sub") != subject: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid token payload", ) if must_change_password and not allow_password_change: raise HTTPException( status_code = status.HTTP_403_FORBIDDEN, detail = "Password change required", ) return subject except jwt.InvalidTokenError: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid or expired token", ) ================================================ FILE: studio/backend/auth/hashing.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Password hashing utilities using PBKDF2. """ import hashlib import hmac import secrets from typing import Tuple def hash_password(password: str, salt: str | None = None) -> Tuple[str, str]: """ Hash a password using PBKDF2-HMAC-SHA256. Returns (salt, hex_hash) tuple. """ if salt is None: salt = secrets.token_hex(16) dk = hashlib.pbkdf2_hmac( "sha256", password.encode("utf-8"), salt.encode("utf-8"), 100_000, # 100k iterations ) return salt, dk.hex() def verify_password(password: str, salt: str, hashed: str) -> bool: """ Verify a password against a stored salt and hash. Uses constant-time comparison to prevent timing attacks. 
""" dk = hashlib.pbkdf2_hmac( "sha256", password.encode("utf-8"), salt.encode("utf-8"), 100_000, ) return hmac.compare_digest(dk.hex(), hashed) ================================================ FILE: studio/backend/auth/storage.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ SQLite storage for authentication data (user credentials + JWT secret). """ import hashlib import secrets import sqlite3 from datetime import datetime, timezone from typing import Optional, Tuple from utils.paths import auth_db_path, ensure_dir DB_PATH = auth_db_path() DEFAULT_ADMIN_USERNAME = "unsloth" # Plaintext bootstrap password file — lives beside auth.db, deleted on # first password change so the credential never lingers on disk. _BOOTSTRAP_PW_PATH = DB_PATH.parent / ".bootstrap_password" # In-process cache so we don't re-read the file on every HTML serve. _bootstrap_password: Optional[str] = None def generate_bootstrap_password() -> str: """Generate a 4-word diceware passphrase and persist it to disk. The passphrase is written to ``_BOOTSTRAP_PW_PATH`` so that it survives server restarts (the DB only stores the *hash*). On subsequent calls / restarts, the persisted value is returned. """ global _bootstrap_password # 1. Already cached in this process? if _bootstrap_password is not None: return _bootstrap_password # 2. Already persisted from a previous run? if _BOOTSTRAP_PW_PATH.is_file(): _bootstrap_password = _BOOTSTRAP_PW_PATH.read_text().strip() if _bootstrap_password: return _bootstrap_password # 3. First-ever startup — generate a fresh passphrase. import diceware _bootstrap_password = diceware.get_passphrase( options = diceware.handle_options(args = ["-n", "4", "-d", "", "-c"]) ) # Persist so the *same* passphrase is used if the server restarts # before the user changes the password. 
ensure_dir(_BOOTSTRAP_PW_PATH.parent) _BOOTSTRAP_PW_PATH.write_text(_bootstrap_password) return _bootstrap_password def get_bootstrap_password() -> Optional[str]: """Return the cached bootstrap password, or None if not yet generated.""" return _bootstrap_password def clear_bootstrap_password() -> None: """Delete the persisted bootstrap password file (called after password change).""" global _bootstrap_password _bootstrap_password = None if _BOOTSTRAP_PW_PATH.is_file(): _BOOTSTRAP_PW_PATH.unlink(missing_ok = True) def _hash_token(token: str) -> str: """SHA-256 hash helper used for refresh token storage.""" return hashlib.sha256(token.encode("utf-8")).hexdigest() def get_connection() -> sqlite3.Connection: """Get a connection to the auth database, creating tables if needed.""" ensure_dir(DB_PATH.parent) conn = sqlite3.connect(DB_PATH) conn.row_factory = sqlite3.Row conn.execute( """ CREATE TABLE IF NOT EXISTS auth_user ( id INTEGER PRIMARY KEY, username TEXT UNIQUE NOT NULL, password_salt TEXT NOT NULL, password_hash TEXT NOT NULL, jwt_secret TEXT NOT NULL, must_change_password INTEGER NOT NULL DEFAULT 0 ); """ ) conn.execute( """ CREATE TABLE IF NOT EXISTS refresh_tokens ( id INTEGER PRIMARY KEY, token_hash TEXT NOT NULL, username TEXT NOT NULL, expires_at TEXT NOT NULL ); """ ) columns = {row["name"] for row in conn.execute("PRAGMA table_info(auth_user)")} if "must_change_password" not in columns: conn.execute( "ALTER TABLE auth_user ADD COLUMN must_change_password INTEGER NOT NULL DEFAULT 0" ) conn.commit() return conn def is_initialized() -> bool: """Check if auth is ready for login (at least one user exists in DB).""" conn = get_connection() cur = conn.execute("SELECT COUNT(*) AS c FROM auth_user") row = cur.fetchone() conn.close() return bool(row["c"]) def create_initial_user( username: str, password: str, jwt_secret: str, *, must_change_password: bool = False, ) -> None: """ Create the initial admin user in the database. 
Raises sqlite3.IntegrityError if username already exists.

    The password is stored as a PBKDF2 salt/hash pair produced by
    ``hashing.hash_password``; ``must_change_password`` is stored as 0/1.
    """
    from .hashing import hash_password

    salt, pwd_hash = hash_password(password)
    conn = get_connection()
    try:
        conn.execute(
            """ INSERT INTO auth_user ( username, password_salt, password_hash, jwt_secret, must_change_password ) VALUES (?, ?, ?, ?, ?) """,
            (username, salt, pwd_hash, jwt_secret, int(must_change_password)),
        )
        conn.commit()
    finally:
        conn.close()


def delete_user(username: str) -> None:
    """
    Delete a user from the database.
    Used for rollback when user creation fails partway through bootstrap.

    Deleting a username that does not exist matches no rows and is a no-op.
    """
    conn = get_connection()
    try:
        conn.execute("DELETE FROM auth_user WHERE username = ?", (username,))
        conn.commit()
    finally:
        conn.close()


def get_user_and_secret(username: str) -> Optional[Tuple[str, str, str, bool]]:
    """
    Get user's password salt, hash, and JWT secret.
    Returns (password_salt, password_hash, jwt_secret, must_change_password)
    or None if user not found. ``must_change_password`` is converted to bool.
    """
    conn = get_connection()
    try:
        cur = conn.execute(
            """ SELECT password_salt, password_hash, jwt_secret, must_change_password FROM auth_user WHERE username = ?
""", (username,), ) row = cur.fetchone() if not row: return None return ( row["password_salt"], row["password_hash"], row["jwt_secret"], bool(row["must_change_password"]), ) finally: conn.close() def get_jwt_secret(username: str) -> Optional[str]: """Return the current JWT signing secret for a user.""" conn = get_connection() try: cur = conn.execute( "SELECT jwt_secret FROM auth_user WHERE username = ?", (username,), ) row = cur.fetchone() return row["jwt_secret"] if row else None finally: conn.close() def requires_password_change(username: str) -> bool: """Return whether the user must change the seeded default password.""" conn = get_connection() try: cur = conn.execute( "SELECT must_change_password FROM auth_user WHERE username = ?", (username,), ) row = cur.fetchone() return bool(row and row["must_change_password"]) finally: conn.close() def load_jwt_secret() -> str: """ Load the JWT secret from the database. Raises RuntimeError if no auth user has been created yet. """ conn = get_connection() try: cur = conn.execute("SELECT jwt_secret FROM auth_user LIMIT 1") row = cur.fetchone() if not row: raise RuntimeError( "Auth is not initialized. Wait for the seeded admin bootstrap to complete." ) return row["jwt_secret"] finally: conn.close() def ensure_default_admin() -> bool: """Seed the default admin account on first startup. Uses a randomly generated diceware passphrase as the bootstrap password. Returns True when the default admin was created in this call. 
""" bootstrap_pw = generate_bootstrap_password() try: create_initial_user( username = DEFAULT_ADMIN_USERNAME, password = bootstrap_pw, jwt_secret = secrets.token_urlsafe(64), must_change_password = True, ) return True except sqlite3.IntegrityError: return False def update_password(username: str, new_password: str) -> bool: """Update password, clear first-login requirement, rotate JWT secret.""" from .hashing import hash_password salt, pwd_hash = hash_password(new_password) jwt_secret = secrets.token_urlsafe(64) conn = get_connection() try: cursor = conn.execute( """ UPDATE auth_user SET password_salt = ?, password_hash = ?, jwt_secret = ?, must_change_password = 0 WHERE username = ? """, (salt, pwd_hash, jwt_secret, username), ) conn.commit() if cursor.rowcount > 0: clear_bootstrap_password() return cursor.rowcount > 0 finally: conn.close() def save_refresh_token(token: str, username: str, expires_at: str) -> None: """ Store a hashed refresh token with its associated username and expiry. """ token_hash = _hash_token(token) conn = get_connection() try: conn.execute( """ INSERT INTO refresh_tokens (token_hash, username, expires_at) VALUES (?, ?, ?) """, (token_hash, username, expires_at), ) conn.commit() finally: conn.close() def verify_refresh_token(token: str) -> Optional[str]: """ Verify a refresh token and return the username. Returns the username if valid and not expired, None otherwise. The token is NOT consumed — it stays valid until it expires. """ token_hash = _hash_token(token) conn = get_connection() try: # Clean up any expired tokens while we're here conn.execute( "DELETE FROM refresh_tokens WHERE expires_at < ?", (datetime.now(timezone.utc).isoformat(),), ) conn.commit() cur = conn.execute( """ SELECT id, username, expires_at FROM refresh_tokens WHERE token_hash = ? 
""", (token_hash,), ) row = cur.fetchone() if row is None: return None # Check expiry expires_at = datetime.fromisoformat(row["expires_at"]) if datetime.now(timezone.utc) > expires_at: conn.execute("DELETE FROM refresh_tokens WHERE id = ?", (row["id"],)) conn.commit() return None return row["username"] finally: conn.close() def revoke_user_refresh_tokens(username: str) -> None: """Revoke all refresh tokens for a user (e.g. on logout).""" conn = get_connection() try: conn.execute("DELETE FROM refresh_tokens WHERE username = ?", (username,)) conn.commit() finally: conn.close() ================================================ FILE: studio/backend/colab.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Colab-specific helpers for running Unsloth Studio. Uses Colab's built-in proxy - no external tunneling needed! """ from pathlib import Path import sys def _bootstrap_studio_venv() -> None: """Expose the Studio venv's site-packages to the current interpreter. On Colab, notebook cells run outside the venv subshell. Instead of installing the full stack into system Python, we prepend the venv's site-packages so that packages like structlog, fastapi, etc. are importable from notebook cells and take priority over system copies. 
""" venv_lib = Path.home() / ".unsloth" / "studio" / ".venv" / "lib" if not venv_lib.exists(): import warnings warnings.warn( f"Studio venv not found at {venv_lib.parent} -- run 'unsloth studio setup' first", stacklevel = 2, ) return for sp in venv_lib.glob("python*/site-packages"): sp_str = str(sp) if sp_str not in sys.path: sys.path.insert(0, sp_str) _bootstrap_studio_venv() # Add backend to path early so local modules like loggers can be imported backend_path = str(Path(__file__).parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from loggers import get_logger logger = get_logger(__name__) def get_colab_url(port: int = 8888) -> str: """ Get the actual Colab proxy URL for a port. """ try: from google.colab.output import eval_js # Use Colab's proxy mechanism url = eval_js(f"google.colab.kernel.proxyPort({port})", timeout_sec = 5) return url if url else f"http://localhost:{port}" except Exception as e: logger.info(f"Note: Could not get Colab URL ({e})") return f"http://localhost:{port}" def show_link(port: int = 8888): """Display a styled clickable link to the UI.""" from IPython.display import display, HTML # Get real Colab proxy URL url = get_colab_url(port) short_url = ( url[: url.index("-", url.index(f"{port}-") + len(str(port)) + 1) + 1] + "..." if f"{port}-" in url else url ) html = f"""

Unsloth Studio is Ready!

Open Unsloth Studio

{short_url}

""" display(HTML(html)) def start(port: int = 8888): """ Start Unsloth Studio server in Colab and display the URL. Usage: from colab import start start() """ import sys logger.info("🦥 Starting Unsloth Studio...") logger.info(" Loading backend...") from run import run_server # Auto-detect frontend path repo_root = Path(__file__).parent.parent frontend_path = repo_root / "frontend" / "dist" if not frontend_path.exists(): logger.info("❌ Frontend not built! Please run the setup cell first.") return logger.info(" Starting server...") # Start server silently run_server(host = "0.0.0.0", port = port, frontend_path = frontend_path, silent = True) logger.info(" Server started!") # Show the clickable link with real URL show_link(port) if __name__ == "__main__": start() ================================================ FILE: studio/backend/core/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Unified core module for Unsloth backend Imports are LAZY (via __getattr__) so that training subprocesses can import core.training.worker without pulling in heavy ML dependencies like unsloth, transformers, or torch before the version activation code has a chance to run. """ import sys from pathlib import Path # Ensure the backend directory is on sys.path so that bare "from utils.*" # imports used throughout the backend work when core is imported as a package # (e.g. from the CLI: "from studio.backend.core import ModelConfig"). 
_backend_dir = str(Path(__file__).resolve().parent.parent) if _backend_dir not in sys.path: sys.path.insert(0, _backend_dir) __all__ = [ # Inference "InferenceBackend", "get_inference_backend", # Training "get_training_backend", "TrainingBackend", "TrainingProgress", # Config "ModelConfig", "is_vision_model", "scan_trained_loras", "load_model_defaults", "get_base_model_from_lora", # Utils "format_and_template_dataset", "normalize_path", "is_local_path", "is_model_cached", "without_hf_auth", "format_error_message", "get_gpu_memory_info", "log_gpu_memory", "get_device", "is_apple_silicon", "clear_gpu_cache", "DeviceType", ] def __getattr__(name): # Inference if name in ("InferenceBackend", "get_inference_backend"): from .inference import InferenceBackend, get_inference_backend globals()["InferenceBackend"] = InferenceBackend globals()["get_inference_backend"] = get_inference_backend return globals()[name] # Training if name in ("TrainingBackend", "get_training_backend", "TrainingProgress"): from .training import TrainingBackend, get_training_backend, TrainingProgress globals()["TrainingBackend"] = TrainingBackend globals()["get_training_backend"] = get_training_backend globals()["TrainingProgress"] = TrainingProgress return globals()[name] # Config (from utils.models) if name in ( "is_vision_model", "ModelConfig", "scan_trained_loras", "load_model_defaults", "get_base_model_from_lora", ): from utils.models import ( is_vision_model, ModelConfig, scan_trained_loras, load_model_defaults, get_base_model_from_lora, ) globals()["is_vision_model"] = is_vision_model globals()["ModelConfig"] = ModelConfig globals()["scan_trained_loras"] = scan_trained_loras globals()["load_model_defaults"] = load_model_defaults globals()["get_base_model_from_lora"] = get_base_model_from_lora return globals()[name] # Paths if name in ("normalize_path", "is_local_path", "is_model_cached"): from utils.paths import normalize_path, is_local_path, is_model_cached globals()["normalize_path"] = 
normalize_path globals()["is_local_path"] = is_local_path globals()["is_model_cached"] = is_model_cached return globals()[name] # Utils if name in ("without_hf_auth", "format_error_message"): from utils.utils import without_hf_auth, format_error_message globals()["without_hf_auth"] = without_hf_auth globals()["format_error_message"] = format_error_message return globals()[name] # Hardware if name in ( "get_device", "is_apple_silicon", "clear_gpu_cache", "get_gpu_memory_info", "log_gpu_memory", "DeviceType", ): from utils.hardware import ( get_device, is_apple_silicon, clear_gpu_cache, get_gpu_memory_info, log_gpu_memory, DeviceType, ) globals()["get_device"] = get_device globals()["is_apple_silicon"] = is_apple_silicon globals()["clear_gpu_cache"] = clear_gpu_cache globals()["get_gpu_memory_info"] = get_gpu_memory_info globals()["log_gpu_memory"] = log_gpu_memory globals()["DeviceType"] = DeviceType return globals()[name] # Datasets if name == "format_and_template_dataset": from utils.datasets import format_and_template_dataset globals()["format_and_template_dataset"] = format_and_template_dataset return format_and_template_dataset raise AttributeError(f"module 'core' has no attribute {name!r}") ================================================ FILE: studio/backend/core/data_recipe/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Data Recipe core (DataDesigner wrapper + job runner). """ from .jobs import JobManager, get_job_manager __all__ = ["JobManager", "get_job_manager"] ================================================ FILE: studio/backend/core/data_recipe/huggingface.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import json from pathlib import Path from utils.paths import recipe_datasets_root, resolve_dataset_path _DATA_DESIGNER_FOOTER = ( 'Made with ❤️ using 🎨 ' 'NeMo Data Designer' ) _UNSLOTH_STUDIO_FOOTER = ( 'Made with ❤️ using 🦥 ' "Unsloth Studio" ) class RecipeDatasetPublishError(ValueError): """Raised when a recipe dataset cannot be published to Hugging Face.""" def _resolve_recipe_artifact_path(artifact_path: str) -> Path: root = recipe_datasets_root().expanduser().resolve() candidate = resolve_dataset_path(artifact_path).expanduser() resolved = candidate.resolve(strict = False) try: resolved.relative_to(root) except ValueError as exc: raise RecipeDatasetPublishError( "This execution artifact is outside the Recipe Studio dataset storage." ) from exc if not resolved.exists(): raise RecipeDatasetPublishError("Execution artifacts are no longer available.") if not resolved.is_dir(): raise RecipeDatasetPublishError( "Execution artifact path is not a dataset folder." ) return resolved def publish_recipe_dataset( *, artifact_path: str, repo_id: str, description: str, hf_token: str | None = None, private: bool = False, ) -> str: dataset_path = _resolve_recipe_artifact_path(artifact_path) try: from data_designer.engine.storage.artifact_storage import ( FINAL_DATASET_FOLDER_NAME, METADATA_FILENAME, PROCESSORS_OUTPUTS_FOLDER_NAME, SDG_CONFIG_FILENAME, ) from data_designer.integrations.huggingface.client import ( HuggingFaceHubClient, HuggingFaceHubClientUploadError, ) from data_designer.integrations.huggingface.dataset_card import ( DataDesignerDatasetCard, ) except ImportError as exc: raise RecipeDatasetPublishError( "NeMo Data Designer Hugging Face integration is not installed." 
) from exc try: client = HuggingFaceHubClient(token = hf_token) client._validate_repo_id(repo_id = repo_id) client._validate_dataset_path(base_dataset_path = dataset_path) client._create_or_get_repo(repo_id = repo_id, private = private) metadata_path = dataset_path / METADATA_FILENAME builder_config_path = dataset_path / SDG_CONFIG_FILENAME with metadata_path.open(encoding = "utf-8") as fh: metadata = json.load(fh) builder_config = None if builder_config_path.exists(): with builder_config_path.open(encoding = "utf-8") as fh: builder_config = json.load(fh) card = DataDesignerDatasetCard.from_metadata( metadata = metadata, builder_config = builder_config, repo_id = repo_id, description = description, tags = None, ) card.text = card.text.replace(_DATA_DESIGNER_FOOTER, _UNSLOTH_STUDIO_FOOTER) # Data Designer currently drops the explicit token when pushing the # dataset card. Push it ourselves so auth stays request-local. card.push_to_hub(repo_id, token = hf_token, repo_type = "dataset") client._upload_main_dataset_files( repo_id = repo_id, parquet_folder = dataset_path / FINAL_DATASET_FOLDER_NAME, ) client._upload_images_folder( repo_id = repo_id, images_folder = dataset_path / "images", ) client._upload_processor_files( repo_id = repo_id, processors_folder = dataset_path / PROCESSORS_OUTPUTS_FOLDER_NAME, ) client._upload_config_files( repo_id = repo_id, metadata_path = metadata_path, builder_config_path = builder_config_path, ) return f"https://huggingface.co/datasets/{repo_id}" except HuggingFaceHubClientUploadError as exc: raise RecipeDatasetPublishError(str(exc)) from exc ================================================ FILE: studio/backend/core/data_recipe/jobs/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
# See /studio/LICENSE.AGPL-3.0

# Package exports: the job runner class and its module-level accessor.
from .manager import JobManager, get_job_manager

__all__ = ["JobManager", "get_job_manager"]

================================================
FILE: studio/backend/core/data_recipe/jobs/constants.py
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

from __future__ import annotations

# stages parsed from data-designer logs
# (stored on Job.stage by the log parser and reported in the job status).
STAGE_CREATE = "create"
STAGE_PREVIEW = "preview"
STAGE_DAG = "dag"
STAGE_HEALTHCHECK = "healthcheck"
STAGE_SAMPLING = "sampling"
STAGE_COLUMN_CONFIG = "column_config"
STAGE_GENERATING = "generating"
STAGE_BATCH = "batch"
STAGE_PROFILING = "profiling"

# Entering any of these stages closes an open "Model usage summary" block in
# the log parser, so later token/request lines are not attributed to a model.
# STAGE_COLUMN_CONFIG and STAGE_BATCH are deliberately absent from this set.
USAGE_RESET_STAGES = {
    STAGE_CREATE,
    STAGE_PREVIEW,
    STAGE_DAG,
    STAGE_HEALTHCHECK,
    STAGE_SAMPLING,
    STAGE_GENERATING,
    STAGE_PROFILING,
}

# job event types emitted by worker/manager
# (carried in each event's "type" field; also used as the SSE event name).
EVENT_JOB_ENQUEUED = "job.enqueued"
EVENT_JOB_STARTED = "job.started"
EVENT_JOB_CANCELLING = "job.cancelling"
EVENT_JOB_CANCELLED = "job.cancelled"
EVENT_JOB_COMPLETED = "job.completed"
EVENT_JOB_ERROR = "job.error"

================================================
FILE: studio/backend/core/data_recipe/jobs/manager.py
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import asyncio import json import queue import threading import time import uuid from pathlib import Path from collections import deque from dataclasses import dataclass from typing import Any import multiprocessing as mp from ..jsonable import to_preview_jsonable from .constants import ( EVENT_JOB_CANCELLING, EVENT_JOB_CANCELLED, EVENT_JOB_COMPLETED, EVENT_JOB_ENQUEUED, EVENT_JOB_ERROR, EVENT_JOB_STARTED, ) from .parse import apply_update, coerce_event, parse_log_message from .types import Job from .worker import run_job_process _CTX = mp.get_context("spawn") @dataclass class Subscription: replay: list[dict] _q: queue.Queue _next_id: int = 0 async def next_event(self, *, timeout_sec: float) -> dict | None: """Wait for next event (SSE), w/ timeout so we can check disconnects.""" try: return await asyncio.to_thread(self._q.get, True, timeout_sec) except queue.Empty: return None def format_sse(self, event: dict) -> bytes: """Turn event dict into SSE bytes (id/event/data).""" event_id = event.get("seq") if event_id is None: self._next_id += 1 event_id = self._next_id body = json.dumps(event, separators = (",", ":"), ensure_ascii = False) event_type = event.get("type") or "message" return ( f"id: {event_id}\n" f"event: {event_type}\n" f"data: {body}\n\n" ).encode("utf-8") class JobManager: def __init__(self) -> None: """Single-job runner (in-mem). 
Simple on purpose, not a whole platform.""" self._lock = threading.Lock() self._job: Job | None = None self._proc: mp.Process | None = None self._mp_q: Any | None = None self._events: deque[dict] = deque(maxlen = 5000) self._subs: list[queue.Queue] = [] self._pump_thread: threading.Thread | None = None self._seq: int = 0 def start(self, *, recipe: dict, run: dict) -> str: """Spawn the job subprocess (one at a time, no cap).""" llm_columns = recipe.get("columns") or [] llm_column_count = 0 if isinstance(llm_columns, list): for column in llm_columns: if not isinstance(column, dict): continue column_type = str(column.get("column_type") or "").strip().lower() if column_type.startswith("llm"): llm_column_count += 1 if llm_column_count <= 0: llm_column_count = 1 with self._lock: if self._proc is not None and self._proc.is_alive(): raise RuntimeError("job already running") job_id = uuid.uuid4().hex self._job = Job(job_id = job_id, status = "pending", started_at = time.time()) self._job.progress_columns_total = llm_column_count self._events.clear() self._seq = 0 run_payload = dict(run) run_payload["_job_id"] = job_id mp_q = _CTX.Queue() proc = _CTX.Process( target = run_job_process, kwargs = {"event_queue": mp_q, "recipe": recipe, "run": run_payload}, daemon = True, ) proc.start() self._mp_q = mp_q self._proc = proc self._pump_thread = threading.Thread(target = self._pump_loop, daemon = True) self._pump_thread.start() self._emit( {"type": EVENT_JOB_ENQUEUED, "ts": time.time(), "job_id": job_id} ) return job_id def cancel(self, job_id: str) -> bool: """Hard stop. We terminate the subprocess. 
Quick + reliable.""" with self._lock: if self._job is None or self._job.job_id != job_id: return False if self._proc is None or not self._proc.is_alive(): return True self._job.status = "cancelling" self._emit( {"type": EVENT_JOB_CANCELLING, "ts": time.time(), "job_id": job_id} ) try: self._proc.terminate() except (AttributeError, OSError): pass return True def get_status(self, job_id: str) -> dict | None: """UI friendly snapshot that we need. Alternative to sse kinda of and structured""" with self._lock: if self._job is None or self._job.job_id != job_id: return None job = self._job return { "job_id": job.job_id, "status": job.status, "stage": job.stage, "current_column": job.current_column, "completed_columns": list(job.completed_columns), "batch": {"idx": job.batch.idx, "total": job.batch.total}, "progress": { "done": job.progress.done, "total": job.progress.total, "percent": job.progress.percent, "eta_sec": job.progress.eta_sec, "rate": job.progress.rate, "ok": job.progress.ok, "failed": job.progress.failed, }, "column_progress": { "done": job.column_progress.done, "total": job.column_progress.total, "percent": job.column_progress.percent, "eta_sec": job.column_progress.eta_sec, "rate": job.column_progress.rate, "ok": job.column_progress.ok, "failed": job.column_progress.failed, }, "model_usage": { name: { "model": usage.model, "tokens": { "input": usage.input_tokens, "output": usage.output_tokens, "total": usage.total_tokens, "tps": usage.tps, }, "requests": { "success": usage.requests_success, "failed": usage.requests_failed, "total": usage.requests_total, "rpm": usage.rpm, }, } for name, usage in job.model_usage.items() }, "rows": job.rows, "cols": job.cols, "error": job.error, "has_analysis": job.analysis is not None, "dataset_rows": None if job.dataset is None else len(job.dataset), "artifact_path": job.artifact_path, "execution_type": job.execution_type, "started_at": job.started_at, "finished_at": job.finished_at, } def get_current_status(self) -> dict | 
None: """Single-job convenience (last/current).""" job_id = self.get_current_job_id() if job_id is None: return None return self.get_status(job_id) def get_current_job_id(self) -> str | None: """Return current job_id (or None).""" with self._lock: return None if self._job is None else self._job.job_id def get_analysis(self, job_id: str) -> dict | None: """Final profiling output (only after job completes).""" with self._lock: if self._job is None or self._job.job_id != job_id: return None return self._job.analysis def get_dataset( self, job_id: str, *, limit: int, offset: int = 0, ) -> dict[str, Any] | None: """Load dataset page (offset + limit) and include total rows.""" with self._lock: if self._job is None or self._job.job_id != job_id: return None in_memory_dataset = self._job.dataset artifact_path = self._job.artifact_path job_status = self._job.status if in_memory_dataset is not None: total = len(in_memory_dataset) rows = in_memory_dataset[offset : offset + limit] return {"dataset": rows, "total": total} if not artifact_path: if job_status in {"completed", "error", "cancelled"}: return {"error": "artifact path missing"} return None try: base_dataset_path = Path(artifact_path) parquet_dir = base_dataset_path / "parquet-files" if not parquet_dir.exists(): return {"error": f"dataset path missing: {parquet_dir}"} return self._load_dataset_page( parquet_dir = parquet_dir, limit = limit, offset = offset ) except Exception as exc: return {"error": f"dataset load failed: {exc}"} @staticmethod def _load_dataset_page( *, parquet_dir: Path, limit: int, offset: int, ) -> dict[str, Any]: dataset_page = JobManager._load_dataset_page_with_duckdb( parquet_dir = parquet_dir, limit = limit, offset = offset, ) if dataset_page is not None: return dataset_page return JobManager._load_dataset_page_with_data_designer( parquet_dir = parquet_dir, limit = limit, offset = offset, ) @staticmethod def _load_dataset_page_with_duckdb( *, parquet_dir: Path, limit: int, offset: int, ) -> 
dict[str, Any] | None: parquet_glob = str((parquet_dir / "*.parquet").resolve()) try: import duckdb # type: ignore except Exception: return None try: conn = duckdb.connect(":memory:") try: total_row = conn.execute( "SELECT COUNT(*) FROM read_parquet(?)", [parquet_glob], ).fetchone() total = int(total_row[0] if total_row else 0) dataframe = conn.execute( ( "SELECT *, row_number() OVER (PARTITION BY filename) AS __row_num__ " "FROM read_parquet(?, filename=true) " "ORDER BY filename, __row_num__ " "LIMIT ? OFFSET ?" ), [parquet_glob, int(limit), int(offset)], ).fetchdf() finally: conn.close() except (RuntimeError, ValueError, duckdb.Error): return None for helper_col in ("filename", "__row_num__"): if helper_col in dataframe.columns: dataframe = dataframe.drop(columns = [helper_col]) rows = dataframe.to_dict(orient = "records") return {"dataset": to_preview_jsonable(rows), "total": total} @staticmethod def _load_dataset_page_with_data_designer( *, parquet_dir: Path, limit: int, offset: int, ) -> dict[str, Any]: from data_designer.config.utils.io_helpers import read_parquet_dataset dataframe = read_parquet_dataset(parquet_dir) total = int(len(dataframe.index)) rows = dataframe.iloc[offset : offset + limit].to_dict(orient = "records") return {"dataset": to_preview_jsonable(rows), "total": total} def subscribe( self, job_id: str, *, after_seq: int | None = None ) -> Subscription | None: """SSE subscribe: get replay buffer + live events stream.""" with self._lock: if self._job is None or self._job.job_id != job_id: return None q: queue.Queue = queue.Queue(maxsize = 2000) self._subs.append(q) if after_seq is None: replay = list(self._events) else: replay = [e for e in self._events if int(e.get("seq") or 0) > after_seq] return Subscription(replay = replay, _q = q) def unsubscribe(self, sub: Subscription) -> None: """Drop SSE subscriber (client disconnected).""" with self._lock: self._subs = [q for q in self._subs if q is not sub._q] def _emit(self, event: dict) -> None: 
"""Broadcast event to replay buffer + all subscribers.""" self._seq += 1 event["seq"] = self._seq self._events.append(event) stale: list[queue.Queue] = [] for q in self._subs: try: q.put_nowait(event) except queue.Full: stale.append(q) if stale: self._subs = [q for q in self._subs if q not in stale] def _snapshot(self) -> tuple[Job, mp.Process, Any] | None: """Grab pointers for the pump loop (avoid holding lock too long).""" with self._lock: if self._job is None or self._proc is None or self._mp_q is None: return None return self._job, self._proc, self._mp_q @staticmethod def _read_queue_with_timeout(q: Any, *, timeout_sec: float) -> dict | None: """Try read 1 event from mp queue. Timeout = pump stays responsive.""" try: return coerce_event(q.get(timeout = timeout_sec)) except queue.Empty: return None except (EOFError, OSError, ValueError): return None @staticmethod def _drain_queue(q: Any) -> list[dict]: """Drain mp queue fast (used on process exit).""" events: list[dict] = [] while True: try: events.append(coerce_event(q.get_nowait())) except queue.Empty: return events except (EOFError, OSError, ValueError): return events def _pump_loop(self) -> None: """Background thread: consumes worker events + updates job snapshot.""" while True: snap = self._snapshot() if snap is None: return job, proc, mp_q = snap event = self._read_queue_with_timeout(mp_q, timeout_sec = 0.25) if event is not None: self._handle_event(job, event) continue if proc.is_alive(): continue for e in self._drain_queue(mp_q): self._handle_event(job, e) with self._lock: if self._job and self._job.status in { "pending", "active", "cancelling", }: if self._job.status == "cancelling": self._job.status = "cancelled" else: self._job.status = "error" self._job.error = self._job.error or "process exited" self._job.finished_at = time.time() event_type = ( EVENT_JOB_CANCELLED if self._job.status == "cancelled" else EVENT_JOB_ERROR ) self._emit( { "type": event_type, "ts": time.time(), "job_id": 
self._job.job_id, } ) return def _handle_event(self, job: Job, event: dict) -> None: """Apply event -> job state + forward to SSE.""" et = event.get("type") msg = event.get("message") if et == "log" else None with self._lock: if self._job is None or self._job.job_id != job.job_id: return if et == EVENT_JOB_STARTED: self._job.status = "active" if et == EVENT_JOB_COMPLETED: self._job.status = "completed" self._job.finished_at = time.time() self._job.analysis = event.get("analysis") self._job.artifact_path = event.get("artifact_path") self._job.execution_type = event.get("execution_type") self._job.dataset = event.get("dataset") self._job.processor_artifacts = event.get("processor_artifacts") if self._job.progress.total and self._job.progress.total > 0: self._job.progress.done = self._job.progress.total self._job.progress.percent = 100.0 if et == EVENT_JOB_ERROR: self._job.status = "error" self._job.finished_at = time.time() self._job.error = event.get("error") or "error" if msg: upd = parse_log_message(msg) if upd: apply_update(self._job, upd) self._emit(event) _JOB_MANAGER: JobManager | None = None def get_job_manager() -> JobManager: """Singleton JobManager (we only run 1 job anyway).""" global _JOB_MANAGER if _JOB_MANAGER is None: _JOB_MANAGER = JobManager() return _JOB_MANAGER ================================================ FILE: studio/backend/core/data_recipe/jobs/parse.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
# See /studio/LICENSE.AGPL-3.0

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

from .constants import (
    STAGE_BATCH,
    STAGE_COLUMN_CONFIG,
    STAGE_CREATE,
    STAGE_DAG,
    STAGE_GENERATING,
    STAGE_HEALTHCHECK,
    STAGE_PREVIEW,
    STAGE_PROFILING,
    STAGE_SAMPLING,
    USAGE_RESET_STAGES,
)
from .types import Job, ModelUsage, Progress


@dataclass(frozen = True)
class ParsedUpdate:
    """Structured result of parsing one data-designer log line.

    Every field defaults to ``None``, meaning "this line said nothing about it".
    ``apply_update`` only touches job fields whose value here is not ``None``.
    """

    # Pipeline stage the line indicates (one of the STAGE_* constants).
    stage: str | None = None
    # Column currently being configured or generated.
    current_column: str | None = None
    # Per-column progress counters from a "progress: ..." line.
    progress: Progress | None = None
    # Record and column counts from the sampler-preparation line.
    rows: int | None = None
    cols: int | None = None
    # "Processing batch <idx> of <total>".
    batch_idx: int | None = None
    batch_total: int | None = None
    # Lines inside the "Model usage summary" block.
    usage_model: str | None = None
    usage_input_tokens: int | None = None
    usage_output_tokens: int | None = None
    usage_total_tokens: int | None = None
    usage_tps: float | None = None
    usage_requests_success: int | None = None
    usage_requests_failed: int | None = None
    usage_requests_total: int | None = None
    usage_rpm: float | None = None
    # True on the "Model usage summary" header line.
    usage_section_start: bool | None = None


# Best-effort parser that turns data-designer log lines into structured status
# for the UI. Scraping logs is brittle, but it is currently the only option.
_RE_SAMPLERS = re.compile( r"Preparing samplers to generate (?P\d+) records across (?P\d+) columns" ) _RE_COLCFG = re.compile(r"model config for column '(?P[^']+)'") _RE_PROCESSING_COL = re.compile(r"Processing .* column '(?P[^']+)'") _RE_PROGRESS = re.compile( r"progress: (?P\d+)/(?P\d+) \((?P\d+)%\) complete, " r"(?P\d+) ok, (?P\d+) failed, (?P[0-9.]+) rec/s, eta (?P[0-9.]+)s" ) _RE_BATCH = re.compile(r"Processing batch (?P\d+) of (?P\d+)") _RE_USAGE_MODEL = re.compile(r"model:\s*(?P.+)$") _RE_USAGE_TOKENS = re.compile( r"tokens:\s*input=(?P\d+),\s*output=(?P\d+),\s*total=(?P\d+),\s*tps=(?P[0-9.]+)" ) _RE_USAGE_REQUESTS = re.compile( r"requests:\s*success=(?P\d+),\s*failed=(?P\d+),\s*total=(?P\d+),\s*rpm=(?P[0-9.]+)" ) def parse_log_message(msg: str) -> ParsedUpdate | None: m = _RE_SAMPLERS.search(msg) if m: return ParsedUpdate( stage = STAGE_SAMPLING, rows = int(m.group("rows")), cols = int(m.group("cols")), ) if "Sorting column configs into a Directed Acyclic Graph" in msg: return ParsedUpdate(stage = STAGE_DAG) if "Running health checks for models" in msg: return ParsedUpdate(stage = STAGE_HEALTHCHECK) if "Preview generation in progress" in msg: return ParsedUpdate(stage = STAGE_PREVIEW) if "Creating Data Designer dataset" in msg: return ParsedUpdate(stage = STAGE_CREATE) if "Measuring dataset column statistics" in msg: return ParsedUpdate(stage = STAGE_PROFILING) m = _RE_COLCFG.search(msg) if m: col = m.group("col") return ParsedUpdate(stage = STAGE_COLUMN_CONFIG, current_column = col) m = _RE_PROCESSING_COL.search(msg) if m: col = m.group("col") return ParsedUpdate(stage = STAGE_GENERATING, current_column = col) m = _RE_PROGRESS.search(msg) if m: p = Progress( done = int(m.group("done")), total = int(m.group("total")), percent = float(m.group("pct")), ok = int(m.group("ok")), failed = int(m.group("failed")), rate = float(m.group("rate")), eta_sec = float(m.group("eta")), ) return ParsedUpdate(stage = STAGE_GENERATING, progress = p) m = _RE_BATCH.search(msg) 
if m: return ParsedUpdate( stage = STAGE_BATCH, batch_idx = int(m.group("idx")), batch_total = int(m.group("total")), ) if "Model usage summary" in msg: return ParsedUpdate(usage_section_start = True) m = _RE_USAGE_MODEL.search(msg) if m and "|-- model:" in msg: return ParsedUpdate(usage_model = str(m.group("model")).strip()) m = _RE_USAGE_TOKENS.search(msg) if m: return ParsedUpdate( usage_input_tokens = int(m.group("input")), usage_output_tokens = int(m.group("output")), usage_total_tokens = int(m.group("total")), usage_tps = float(m.group("tps")), ) m = _RE_USAGE_REQUESTS.search(msg) if m: return ParsedUpdate( usage_requests_success = int(m.group("success")), usage_requests_failed = int(m.group("failed")), usage_requests_total = int(m.group("total")), usage_rpm = float(m.group("rpm")), ) return None def apply_update(job: Job, update: ParsedUpdate) -> None: if update.stage is not None: job.stage = update.stage if update.current_column is not None: job.current_column = update.current_column if ( update.stage == STAGE_GENERATING and update.current_column not in job._seen_generation_columns ): job._seen_generation_columns.append(update.current_column) if update.rows is not None: job.rows = update.rows if update.cols is not None: job.cols = update.cols if update.progress is not None: job.column_progress = update.progress if ( job.current_column and update.progress.done is not None and update.progress.total is not None and update.progress.total > 0 and update.progress.done >= update.progress.total and job.current_column not in job.completed_columns ): job.completed_columns.append(job.current_column) job.progress = _compute_overall_progress(job, update.progress) if update.batch_idx is not None: job.batch.idx = update.batch_idx if update.batch_total is not None: job.batch.total = update.batch_total if update.stage in USAGE_RESET_STAGES: # usage summary is a short block so we reset once we move into the next stage. 
job._in_usage_summary = False if update.usage_section_start is not None: job._in_usage_summary = update.usage_section_start if update.usage_section_start: job._current_usage_model = None if not job._in_usage_summary: return if update.usage_model is not None: name = update.usage_model.strip().strip("'").strip('"') job._current_usage_model = name if name not in job.model_usage: job.model_usage[name] = ModelUsage(model = name) if job._current_usage_model is None: return usage = job.model_usage.get(job._current_usage_model) if usage is None: return if update.usage_input_tokens is not None: usage.input_tokens = update.usage_input_tokens if update.usage_output_tokens is not None: usage.output_tokens = update.usage_output_tokens if update.usage_total_tokens is not None: usage.total_tokens = update.usage_total_tokens if update.usage_tps is not None: usage.tps = update.usage_tps if update.usage_requests_success is not None: usage.requests_success = update.usage_requests_success if update.usage_requests_failed is not None: usage.requests_failed = update.usage_requests_failed if update.usage_requests_total is not None: usage.requests_total = update.usage_requests_total if update.usage_rpm is not None: usage.rpm = update.usage_rpm def _compute_overall_progress(job: Job, column_progress: Progress) -> Progress: if not job.rows: return column_progress total_rows = max(1, int(job.rows)) current_done = 0 if column_progress.done is None else int(column_progress.done) current_done = max(0, min(current_done, total_rows)) total_columns = max(1, int(job.progress_columns_total or 1)) if job.current_column: job._column_done[job.current_column] = current_done if len(job._column_done) == 0: done = current_done else: sum_done = sum( max(0, min(value, total_rows)) for value in job._column_done.values() ) done = int(sum_done / total_columns) prev_done = int(job.progress.done or 0) if done < prev_done: done = prev_done if done > total_rows: done = total_rows percent = (done / total_rows) * 100 
if total_rows > 0 else 100.0 prev_percent = float(job.progress.percent or 0.0) if percent < prev_percent: percent = prev_percent return Progress( done = done, total = total_rows, percent = percent, eta_sec = column_progress.eta_sec, rate = column_progress.rate, ok = column_progress.ok, failed = column_progress.failed, ) def coerce_event(obj: Any) -> dict: """Normalize worker payload into event dict.""" return obj if isinstance(obj, dict) else {"type": "log", "message": str(obj)} ================================================ FILE: studio/backend/core/data_recipe/jobs/types.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Literal JobStatus = Literal[ "created", "pending", "active", "cancelling", "cancelled", "error", "completed", ] @dataclass class Progress: done: int | None = None total: int | None = None percent: float | None = None eta_sec: float | None = None rate: float | None = None ok: int | None = None failed: int | None = None @dataclass class BatchProgress: idx: int | None = None total: int | None = None @dataclass class ModelUsage: model: str input_tokens: int | None = None output_tokens: int | None = None total_tokens: int | None = None tps: float | None = None requests_success: int | None = None requests_failed: int | None = None requests_total: int | None = None rpm: float | None = None @dataclass class Job: job_id: str status: JobStatus = "created" stage: str | None = None current_column: str | None = None progress: Progress = field(default_factory = Progress) column_progress: Progress = field(default_factory = Progress) batch: BatchProgress = field(default_factory = BatchProgress) rows: int | None = None cols: int | None = None error: str | None = None started_at: float | None = None finished_at: float | 
None = None analysis: dict[str, Any] | None = None artifact_path: str | None = None execution_type: str | None = None dataset: list[dict[str, Any]] | None = None processor_artifacts: dict[str, Any] | None = None model_usage: dict[str, ModelUsage] = field(default_factory = dict) progress_columns_total: int | None = None completed_columns: list[str] = field(default_factory = list) _current_usage_model: str | None = None _in_usage_summary: bool = False _seen_generation_columns: list[str] = field(default_factory = list) _column_done: dict[str, int] = field(default_factory = dict) ================================================ FILE: studio/backend/core/data_recipe/jobs/worker.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import json import structlog import loggers import logging import re import shutil import time import traceback import unicodedata from pathlib import Path from typing import Any from ..jsonable import to_jsonable, to_preview_jsonable from .constants import EVENT_JOB_COMPLETED, EVENT_JOB_ERROR, EVENT_JOB_STARTED from ..service import build_config_builder, create_data_designer from utils.paths import ensure_dir, recipe_datasets_root _ARTIFACT_ROOT = recipe_datasets_root() class _QueueLogHandler(logging.Handler): def __init__(self, event_queue): super().__init__() self._q = event_queue def emit(self, record: logging.LogRecord) -> None: try: event = { "type": "log", "ts": record.created, "level": record.levelname, "logger": record.name, "message": record.getMessage(), } self._q.put(event) except (OSError, RuntimeError, ValueError): pass def _slugify_run_name(value: str) -> str: normalized = unicodedata.normalize("NFKD", value) ascii_only = normalized.encode("ascii", "ignore").decode("ascii") slug = re.sub(r"[^a-zA-Z0-9]+", "-", ascii_only).strip("-").lower() if not slug: 
return "" return slug[:80].strip("-") def _build_dataset_name( *, run_name: str | None, job_id: str, artifact_root: Path ) -> str: fallback = f"recipe_{job_id}" slug = _slugify_run_name(run_name or "") base_name = f"recipe_{slug}" if slug else fallback candidate = base_name suffix = 2 while (artifact_root / candidate).exists(): candidate = f"{base_name}_{suffix}" suffix += 1 return candidate def run_job_process( *, event_queue, recipe: dict[str, Any], run: dict[str, Any], ) -> None: """ Subprocess entrypoint. Sends events to `event_queue`. """ import os os.environ["PYTHONWARNINGS"] = ( "ignore" # Suppress warnings at C-level before imports ) import warnings from loggers.config import LogConfig if os.getenv("ENVIRONMENT_TYPE", "production") == "production": warnings.filterwarnings("ignore") LogConfig.setup_logging( service_name = "unsloth-studio-data-worker", env = os.getenv("ENVIRONMENT_TYPE", "production"), ) event_queue.put({"type": EVENT_JOB_STARTED, "ts": time.time()}) try: from data_designer.config.run_config import RunConfig rows = int(run.get("rows") or 1000) job_id = str(run.get("_job_id") or "").strip() if not job_id: job_id = f"{int(time.time())}" run_name_raw = run.get("run_name") run_name = run_name_raw if isinstance(run_name_raw, str) else None dataset_name = _build_dataset_name( run_name = run_name, job_id = job_id, artifact_root = _ARTIFACT_ROOT, ) merge_batches = bool(run.get("merge_batches")) ensure_dir(_ARTIFACT_ROOT) run_config_raw = run.get("run_config") or {} builder = build_config_builder(recipe) designer = create_data_designer(recipe, artifact_path = str(_ARTIFACT_ROOT)) # DataDesigner configures root logging in DataDesigner.__init__. # Attach queue logger directly to `data_designer` so parser events survive root resets. 
handler = _QueueLogHandler(event_queue) handler.setLevel(logging.INFO) data_designer_logger = logging.getLogger("data_designer") data_designer_logger.addHandler(handler) data_designer_logger.setLevel(logging.INFO) data_designer_logger.propagate = True if run_config_raw: designer.set_run_config(RunConfig.model_validate(run_config_raw)) execution_type = str(run.get("execution_type") or "full").strip().lower() if execution_type == "preview": results = designer.preview(builder, num_records = rows) analysis = ( None if results.analysis is None else to_jsonable(results.analysis.model_dump(mode = "json")) ) dataset = ( [] if results.dataset is None else to_preview_jsonable(results.dataset.to_dict(orient = "records")) ) processor_artifacts = ( None if results.processor_artifacts is None else to_jsonable(results.processor_artifacts) ) event_queue.put( { "type": EVENT_JOB_COMPLETED, "ts": time.time(), "analysis": analysis, "dataset": dataset, "processor_artifacts": processor_artifacts, "artifact_path": None, "execution_type": execution_type, } ) else: results = designer.create( builder, num_records = rows, dataset_name = dataset_name ) analysis = to_jsonable(results.load_analysis().model_dump(mode = "json")) if merge_batches: _merge_batches_to_single_parquet( results.artifact_storage.base_dataset_path ) artifact_path = str(results.artifact_storage.base_dataset_path) event_queue.put( { "type": EVENT_JOB_COMPLETED, "ts": time.time(), "analysis": analysis, "artifact_path": artifact_path, "execution_type": execution_type, } ) except Exception as exc: event_queue.put( { "type": EVENT_JOB_ERROR, "ts": time.time(), "error": str(exc), "stack": traceback.format_exc(limit = 20), } ) def _merge_batches_to_single_parquet(base_dataset_path: Path) -> None: parquet_dir = base_dataset_path / "parquet-files" parquet_files = sorted(parquet_dir.glob("*.parquet")) if len(parquet_files) <= 1: return try: from data_designer.config.utils.io_helpers import read_parquet_dataset except ImportError: 
return dataframe = read_parquet_dataset(parquet_dir) shutil.rmtree(parquet_dir) parquet_dir.mkdir(parents = True, exist_ok = True) merged_file = parquet_dir / "batch_00000.parquet" dataframe.to_parquet(merged_file, index = False) _rewrite_merged_metadata( base_dataset_path = base_dataset_path, parquet_file = merged_file, ) def _rewrite_merged_metadata(*, base_dataset_path: Path, parquet_file: Path) -> None: metadata_path = base_dataset_path / "metadata.json" if not metadata_path.exists(): return try: metadata = json.loads(metadata_path.read_text(encoding = "utf-8")) except (OSError, TypeError, ValueError): return if not isinstance(metadata, dict): return relative_parquet_path = str(parquet_file.relative_to(base_dataset_path)) file_paths = metadata.get("file_paths") if not isinstance(file_paths, dict): file_paths = {} file_paths["parquet-files"] = [relative_parquet_path] metadata["file_paths"] = file_paths metadata["total_num_batches"] = 1 metadata["num_completed_batches"] = 1 try: metadata_path.write_text( json.dumps(metadata, indent = 2, sort_keys = True), encoding = "utf-8", ) except OSError: return ================================================ FILE: studio/backend/core/data_recipe/jsonable.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import base64 import io from pathlib import Path from typing import Any def _pil_to_preview_payload(image: Any) -> dict[str, Any]: buffer = io.BytesIO() image.convert("RGB").save(buffer, format = "JPEG", quality = 85) return { "type": "image", "mime": "image/jpeg", "width": image.width, "height": image.height, "data": base64.b64encode(buffer.getvalue()).decode("ascii"), } def _open_pil_image_from_bytes(raw_bytes: bytes): from PIL import Image # type: ignore with Image.open(io.BytesIO(raw_bytes)) as image: return image.copy() def _to_pil_from_hf_image_dict(value: Any) -> Any | None: if not isinstance(value, dict): return None raw_bytes = value.get("bytes") if isinstance(raw_bytes, (bytes, bytearray)) and len(raw_bytes) > 0: try: return _open_pil_image_from_bytes(bytes(raw_bytes)) except (OSError, ValueError): pass if ( isinstance(raw_bytes, list) and len(raw_bytes) > 0 and all(isinstance(item, int) and 0 <= item <= 255 for item in raw_bytes) ): try: return _open_pil_image_from_bytes(bytes(raw_bytes)) except (OSError, ValueError): pass path_value = value.get("path") if isinstance(path_value, str) and path_value.strip(): try: from PIL import Image # type: ignore with Image.open(Path(path_value)) as image: return image.copy() except (OSError, ValueError, TypeError): return None return None def to_jsonable(value: Any) -> Any: """Convert numpy/pandas-ish values into plain JSON-safe values.""" try: import numpy as np # type: ignore except ImportError: # pragma: no cover np = None # type: ignore if np is not None: if isinstance(value, np.ndarray): return value.tolist() if isinstance(value, np.generic): return value.item() if isinstance(value, dict): return {str(k): to_jsonable(v) for k, v in value.items()} if isinstance(value, (list, tuple, set)): return [to_jsonable(v) for v in value] if hasattr(value, "isoformat") and callable(value.isoformat): try: return value.isoformat() except (TypeError, ValueError): 
return value return value def _to_preview_image_payload(value: Any) -> dict[str, Any] | None: try: from PIL.Image import Image as PILImage # type: ignore except ImportError: # pragma: no cover return None if not isinstance(value, PILImage): hf_image = _to_pil_from_hf_image_dict(value) if hf_image is None: return None value = hf_image return _pil_to_preview_payload(value) def to_preview_jsonable(value: Any) -> Any: """Convert values into JSON-safe preview values, including PIL images.""" image_payload = _to_preview_image_payload(value) if image_payload is not None: return image_payload converted = to_jsonable(value) if converted is None or isinstance(converted, (str, int, float, bool)): return converted if isinstance(converted, dict): return {str(k): to_preview_jsonable(v) for k, v in converted.items()} if isinstance(converted, (list, tuple, set)): return [to_preview_jsonable(v) for v in converted] if isinstance(converted, (bytes, bytearray)): return base64.b64encode(bytes(converted)).decode("ascii") return str(converted) ================================================ FILE: studio/backend/core/data_recipe/local_callable_validators.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import json import os import structlog import subprocess from copy import deepcopy from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Any from loggers import get_logger from utils.paths import ensure_dir, oxc_validator_tmp_root logger = get_logger(__name__) OXC_VALIDATION_FN_MARKER = "unsloth_oxc_validator" _OXC_LANG_TO_NODE_LANG = { "javascript": "js", "typescript": "ts", "jsx": "jsx", "tsx": "tsx", } _OXC_VALIDATION_MODES = {"syntax", "lint", "syntax+lint"} _OXC_CODE_SHAPES = {"auto", "module", "snippet"} _OXC_TOOL_DIR = Path(__file__).resolve().parent / "oxc-validator" _OXC_RUNNER_PATH = _OXC_TOOL_DIR / "validate.mjs" @dataclass(frozen = True) class OxcLocalCallableValidatorSpec: name: str drop: bool target_columns: list[str] batch_size: int code_lang: str validation_mode: str code_shape: str def split_oxc_local_callable_validators( recipe_core: dict[str, Any], ) -> tuple[dict[str, Any], list[OxcLocalCallableValidatorSpec]]: columns = recipe_core.get("columns") if not isinstance(columns, list): return recipe_core, [] sanitized = deepcopy(recipe_core) sanitized_columns = sanitized.get("columns") if not isinstance(sanitized_columns, list): return sanitized, [] kept_columns: list[Any] = [] oxc_specs: list[OxcLocalCallableValidatorSpec] = [] for column in sanitized_columns: if not isinstance(column, dict): kept_columns.append(column) continue maybe_spec = _parse_oxc_spec(column = column) if maybe_spec is None: kept_columns.append(column) continue oxc_specs.append(maybe_spec) sanitized["columns"] = kept_columns return sanitized, oxc_specs def register_oxc_local_callable_validators( *, builder, specs: list[OxcLocalCallableValidatorSpec], ) -> None: if not specs: return from data_designer.config.column_configs import ValidationColumnConfig from data_designer.config.validator_params import ( LocalCallableValidatorParams, ValidatorType, ) for spec in 
specs: validation_function = _build_oxc_validation_function( spec.code_lang, spec.validation_mode, spec.code_shape, ) builder.add_column( ValidationColumnConfig( name = spec.name, drop = spec.drop, target_columns = spec.target_columns, validator_type = ValidatorType.LOCAL_CALLABLE, validator_params = LocalCallableValidatorParams( validation_function = validation_function, ), batch_size = spec.batch_size, ) ) def _parse_oxc_spec( *, column: dict[str, Any], ) -> OxcLocalCallableValidatorSpec | None: if str(column.get("column_type") or "").strip() != "validation": return None if str(column.get("validator_type") or "").strip() != "local_callable": return None params = column.get("validator_params") if not isinstance(params, dict): return None fn_raw = params.get("validation_function") fn_name = fn_raw.strip() if isinstance(fn_raw, str) else "" if not fn_name.startswith(OXC_VALIDATION_FN_MARKER): return None name = str(column.get("name") or "").strip() if not name: return None target_columns_raw = column.get("target_columns") target_columns = ( [ value.strip() for value in target_columns_raw if isinstance(value, str) and value.strip() ] if isinstance(target_columns_raw, list) else [] ) if not target_columns: return None code_lang, validation_mode, code_shape = _parse_oxc_validation_marker(fn_name) batch_size = _parse_batch_size(column.get("batch_size")) drop = bool(column.get("drop") is True) return OxcLocalCallableValidatorSpec( name = name, drop = drop, target_columns = target_columns, batch_size = batch_size, code_lang = code_lang, validation_mode = validation_mode, code_shape = code_shape, ) def _parse_batch_size(value: Any) -> int: try: parsed = int(value) except (TypeError, ValueError): return 10 return parsed if parsed >= 1 else 10 def _parse_oxc_validation_marker(fn_name: str) -> tuple[str, str, str]: marker = f"{OXC_VALIDATION_FN_MARKER}:" if not fn_name.startswith(marker): return "javascript", "syntax", "auto" suffix = fn_name[len(marker) :] parts = 
[part.strip() for part in suffix.split(":") if part.strip()] if len(parts) < 2: return "javascript", "syntax", "auto" code_lang = parts[0] if parts[0] in _OXC_LANG_TO_NODE_LANG else "javascript" mode = parts[1] if parts[1] in _OXC_VALIDATION_MODES else "syntax" code_shape = ( parts[2] if len(parts) >= 3 and parts[2] in _OXC_CODE_SHAPES else "auto" ) return code_lang, mode, code_shape @lru_cache(maxsize = 8) def _build_oxc_validation_function(lang: str, validation_mode: str, code_shape: str): node_lang = _OXC_LANG_TO_NODE_LANG.get(lang, "js") mode = validation_mode if validation_mode in _OXC_VALIDATION_MODES else "syntax" normalized_code_shape = code_shape if code_shape in _OXC_CODE_SHAPES else "auto" def _validator(df): import pandas as pd # imported lazily for local callable runtime row_count = int(len(df.index)) if row_count == 0: return pd.DataFrame({"is_valid": []}) code_column = str(df.columns[0]) if len(df.columns) > 0 else "" code_values = ( ["" for _ in range(row_count)] if not code_column else [ "" if value is None else str(value) for value in df[code_column].tolist() ] ) results = _run_oxc_batch( node_lang = node_lang, validation_mode = mode, code_shape = normalized_code_shape, code_values = code_values, ) if len(results) != row_count: results = _fallback_results( row_count, "OXC validator returned mismatched result size.", ) return pd.DataFrame(results) _validator.__name__ = f"{OXC_VALIDATION_FN_MARKER}_{node_lang}_{mode.replace('+', '_')}_{normalized_code_shape}" return _validator def _run_oxc_batch( *, node_lang: str, validation_mode: str, code_shape: str, code_values: list[str], ) -> list[dict[str, Any]]: if not _OXC_RUNNER_PATH.exists(): return _fallback_results( len(code_values), f"OXC runner missing at {_OXC_RUNNER_PATH}", ) payload = { "lang": node_lang, "mode": validation_mode, "code_shape": code_shape, "codes": code_values, } try: tmp_dir = ensure_dir(oxc_validator_tmp_root()) env = dict(os.environ) tmp_dir_str = str(tmp_dir) env["TMPDIR"] = 
tmp_dir_str env["TMP"] = tmp_dir_str env["TEMP"] = tmp_dir_str proc = subprocess.run( ["node", str(_OXC_RUNNER_PATH)], cwd = str(_OXC_TOOL_DIR), input = json.dumps(payload), text = True, capture_output = True, check = False, env = env, ) except (OSError, ValueError) as exc: logger.warning("OXC subprocess launch failed: %s", exc) return _fallback_results(len(code_values), f"OXC launch failed: {exc}") if proc.returncode != 0: message = (proc.stderr or proc.stdout or "unknown error").strip() if len(message) > 300: message = f"{message[:300]}..." return _fallback_results(len(code_values), f"OXC failed: {message}") try: raw = json.loads(proc.stdout) except json.JSONDecodeError: return _fallback_results(len(code_values), "OXC output parse failed.") if not isinstance(raw, list): return _fallback_results(len(code_values), "OXC output must be an array.") out: list[dict[str, Any]] = [] for item in raw: if not isinstance(item, dict): out.append( { "is_valid": False, "error_count": 1, "error_message": "Invalid OXC result entry.", "severity": None, "code": None, "labels": [], "codeframe": None, "warning_count": 0, } ) continue is_valid_raw = item.get("is_valid") error_count_raw = item.get("error_count") message_raw = item.get("error_message") severity_raw = item.get("severity") code_raw = item.get("code") labels_raw = item.get("labels") codeframe_raw = item.get("codeframe") warning_count_raw = item.get("warning_count") out.append( { "is_valid": bool(is_valid_raw) if isinstance(is_valid_raw, bool) else False, "error_count": int(error_count_raw) if isinstance(error_count_raw, int) else 0, "error_message": str(message_raw or ""), "severity": str(severity_raw) if isinstance(severity_raw, str) else None, "code": str(code_raw) if isinstance(code_raw, str) else None, "labels": labels_raw if isinstance(labels_raw, list) else [], "codeframe": str(codeframe_raw) if isinstance(codeframe_raw, str) else None, "warning_count": int(warning_count_raw) if isinstance(warning_count_raw, int) else 
0, } ) return out def _fallback_results(row_count: int, message: str) -> list[dict[str, Any]]: return [ { "is_valid": False, "error_count": 1, "error_message": message, "severity": None, "code": None, "labels": [], "codeframe": None, "warning_count": 0, } for _ in range(row_count) ] ================================================ FILE: studio/backend/core/data_recipe/oxc-validator/package.json ================================================ { "name": "unsloth-oxc-validator-runtime", "private": true, "version": "0.0.1", "type": "module", "dependencies": { "oxc-parser": "^0.116.0", "oxlint": "^1.51.0" } } ================================================ FILE: studio/backend/core/data_recipe/oxc-validator/validate.mjs ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { spawnSync } from "node:child_process"; import { mkdtempSync, rmSync, writeFileSync } from "node:fs"; import { tmpdir } from "node:os"; import { basename, dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { parseSync } from "oxc-parser"; const LANG_TO_EXT = { js: "js", jsx: "jsx", ts: "ts", tsx: "tsx", }; const VALIDATION_MODES = new Set(["syntax", "lint", "syntax+lint"]); const CODE_SHAPES = new Set(["auto", "module", "snippet"]); const SNIPPET_PREFIX = "(() => {\n"; const SNIPPET_SUFFIX = "\n})();\nexport {};\n"; const OXLINT_SUPPRESSED_RULES = ["no-unused-vars", "no-new-array"]; const TOOL_DIR = dirname(fileURLToPath(import.meta.url)); function mapLang(value) { const normalized = String(value || "").trim().toLowerCase(); if (normalized === "javascript" || normalized === "js") { return "js"; } if (normalized === "typescript" || normalized === "ts") { return "ts"; } if (normalized === "jsx") { return "jsx"; } if (normalized === "tsx") { return "tsx"; } return "js"; } function mapMode(value) { const normalized = 
String(value || "").trim().toLowerCase(); if (VALIDATION_MODES.has(normalized)) { return normalized; } return "syntax"; } function mapCodeShape(value) { const normalized = String(value || "").trim().toLowerCase(); if (CODE_SHAPES.has(normalized)) { return normalized; } return "auto"; } function parseFileIndex(filePath) { if (typeof filePath !== "string") { return null; } const match = basename(filePath).match(/^snippet_(\d+)\./); if (!match) { return null; } const parsed = Number.parseInt(match[1], 10); return Number.isFinite(parsed) ? parsed : null; } function toCodeString(code) { return typeof code === "string" ? code : String(code ?? ""); } function makeValidationEntry({ code, index, lang, codeShape }) { const source = toCodeString(code); if (codeShape === "snippet") { return { index, lang, code: `${SNIPPET_PREFIX}${source}${SNIPPET_SUFFIX}`, offset: SNIPPET_PREFIX.length, }; } return { index, lang, code: source, offset: 0, }; } function shiftOffset(value, offset) { if (!Number.isInteger(value)) { return null; } const shifted = value - offset; return shifted >= 0 ? shifted : null; } function remapDiagnosticOffsets(diagnostic, offset) { if (!diagnostic || typeof diagnostic !== "object" || offset <= 0) { return diagnostic; } return { ...diagnostic, labels: Array.isArray(diagnostic.labels) ? diagnostic.labels.map((label) => ({ ...label, start: shiftOffset(label.start, offset), end: shiftOffset(label.end, offset), })) : [], }; } function normalizeParserError(error) { if (typeof error === "string") { return { code: null, message: error.trim() || "Unknown parser error", severity: null, labels: [], codeframe: null, }; } if (!error || typeof error !== "object") { return { code: null, message: "Unknown parser error", severity: null, labels: [], codeframe: null, }; } const code = typeof error.code === "string" ? 
error.code : null; const message = String(error.message || error.reason || "").trim() || "Unknown parser error"; const severity = typeof error.severity === "string" ? error.severity : null; const labels = Array.isArray(error.labels) ? error.labels.map((label) => ({ message: label && typeof label === "object" && typeof label.message === "string" ? label.message : null, start: label && typeof label === "object" && Number.isInteger(label.start) ? label.start : null, end: label && typeof label === "object" && Number.isInteger(label.end) ? label.end : null, })) : []; const codeframe = typeof error.codeframe === "string" ? error.codeframe : null; return { code, message, severity, labels, codeframe, }; } function normalizeLintDiagnostic(diagnostic) { if (!diagnostic || typeof diagnostic !== "object") { return null; } const readString = (value) => typeof value === "string" ? value : null; const readInt = (value) => Number.isInteger(value) ? value : null; const asObject = (value) => value && typeof value === "object" ? value : null; const message = String(diagnostic.message || "").trim(); if (!message) { return null; } const severityRaw = String(diagnostic.severity || "").trim().toLowerCase(); const severity = severityRaw === "error" ? "error" : "warning"; const labels = []; if (Array.isArray(diagnostic.labels)) { for (const label of diagnostic.labels) { const labelObj = asObject(label); const span = asObject(labelObj?.span); const start = readInt(span?.offset); const length = readInt(span?.length); labels.push({ message: readString(labelObj?.label), start, end: start !== null && length !== null ? start + length : null, }); } } const code = typeof diagnostic.code === "string" ? diagnostic.code : null; return { code, message: code ? 
`${code}: ${message}` : message, severity, labels, codeframe: null, }; } function makeResult({ isValid, errorCount, warningCount = 0, message = "", severity = null, code = null, labels = [], codeframe = null, }) { return { is_valid: Boolean(isValid), error_count: Number.isInteger(errorCount) ? errorCount : 0, warning_count: Number.isInteger(warningCount) ? warningCount : 0, error_message: String(message || ""), severity: typeof severity === "string" ? severity : null, code: typeof code === "string" ? code : null, labels: Array.isArray(labels) ? labels : [], codeframe: typeof codeframe === "string" ? codeframe : null, }; } function syntaxResultFromErrors(errors) { const first = errors[0] ?? null; return makeResult({ isValid: errors.length === 0, errorCount: errors.length, warningCount: 0, message: errors.slice(0, 3).map((error) => error.message).join(" | "), severity: first ? first.severity : null, code: first ? first.code : null, labels: first ? first.labels : [], codeframe: first ? first.codeframe : null, }); } function runSyntaxParse(entry) { const ext = LANG_TO_EXT[entry.lang] ?? "js"; const filename = `snippet_${entry.index}.${ext}`; try { const parsed = parseSync(filename, entry.code, { lang: entry.lang, sourceType: "module", showSemanticErrors: true, }); const errors = Array.isArray(parsed?.errors) ? 
parsed.errors .map(normalizeParserError) .filter(Boolean) .map((error) => remapDiagnosticOffsets(error, entry.offset)) : []; return errors; } catch (error) { return [ remapDiagnosticOffsets( normalizeParserError(error), entry.offset, ), ]; } } function pickPreferredErrorList(firstErrors, secondErrors) { if (secondErrors.length < firstErrors.length) { return secondErrors; } return firstErrors; } function validateSyntaxOne({ code, lang, index, codeShape }) { if (codeShape !== "auto") { const lintEntry = makeValidationEntry({ code, index, lang, codeShape, }); const errors = runSyntaxParse(lintEntry); return { result: syntaxResultFromErrors(errors), lintEntry, }; } const moduleEntry = makeValidationEntry({ code, index, lang, codeShape: "module", }); const moduleErrors = runSyntaxParse(moduleEntry); if (moduleErrors.length === 0) { return { result: syntaxResultFromErrors(moduleErrors), lintEntry: moduleEntry, }; } const snippetEntry = makeValidationEntry({ code, index, lang, codeShape: "snippet", }); const snippetErrors = runSyntaxParse(snippetEntry); if (snippetErrors.length === 0) { return { result: syntaxResultFromErrors(snippetErrors), lintEntry: snippetEntry, }; } const chosenErrors = pickPreferredErrorList(moduleErrors, snippetErrors); const lintEntry = chosenErrors === snippetErrors ? 
snippetEntry : moduleEntry; return { result: syntaxResultFromErrors(chosenErrors), lintEntry, }; } function resolveLintEntry({ code, lang, index, codeShape }) { if (codeShape !== "auto") { return makeValidationEntry({ code, index, lang, codeShape, }); } const moduleEntry = makeValidationEntry({ code, index, lang, codeShape: "module", }); if (runSyntaxParse(moduleEntry).length === 0) { return moduleEntry; } const snippetEntry = makeValidationEntry({ code, index, lang, codeShape: "snippet", }); if (runSyntaxParse(snippetEntry).length === 0) { return snippetEntry; } return moduleEntry; } function fallbackLintResults(entries, message) { return new Map( entries.map((entry) => [ entry.index, makeResult({ isValid: false, errorCount: 1, warningCount: 0, message, severity: "error", }), ]), ); } function runLintBatch(entries) { if (entries.length === 0) { return new Map(); } const entryByIndex = new Map(entries.map((entry) => [entry.index, entry])); const tempDir = mkdtempSync(join(tmpdir(), "oxlint-")); try { for (const entry of entries) { const ext = LANG_TO_EXT[entry.lang] ?? "js"; const filePath = join(tempDir, `snippet_${entry.index}.${ext}`); writeFileSync(filePath, entry.code, "utf8"); } const oxlintBin = join(TOOL_DIR, "node_modules", ".bin", "oxlint"); const oxlintArgs = [ ...OXLINT_SUPPRESSED_RULES.flatMap((rule) => ["-A", rule]), "--format", "json", tempDir, ]; const exec = spawnSync(oxlintBin, oxlintArgs, { encoding: "utf8", cwd: TOOL_DIR, }); if (exec.error) { return fallbackLintResults( entries, `oxlint execution failed: ${exec.error.message}`, ); } const stdout = String(exec.stdout || "").trim(); if (!stdout) { const stderr = String(exec.stderr || "").trim(); return fallbackLintResults( entries, stderr || "oxlint returned empty output", ); } let parsed; try { parsed = JSON.parse(stdout); } catch { return fallbackLintResults(entries, "oxlint JSON parse failed"); } const rawDiagnostics = Array.isArray(parsed?.diagnostics) ? 
parsed.diagnostics : []; const byIndex = new Map(); for (const diag of rawDiagnostics) { const filenameRaw = typeof diag?.filename === "string" ? diag.filename : ""; const filename = filenameRaw.startsWith("file://") ? filenameRaw.replace("file://", "") : filenameRaw; const index = parseFileIndex(filename); if (index === null) { continue; } const normalized = normalizeLintDiagnostic(diag); if (!normalized) { continue; } const entry = entryByIndex.get(index); const remapped = remapDiagnosticOffsets(normalized, entry?.offset ?? 0); const list = byIndex.get(index) ?? []; list.push(remapped); byIndex.set(index, list); } const results = new Map(); for (const entry of entries) { const diagnostics = byIndex.get(entry.index) ?? []; const errorDiagnostics = diagnostics.filter( (diag) => diag.severity === "error", ); const warningDiagnostics = diagnostics.filter( (diag) => diag.severity !== "error", ); const top = errorDiagnostics[0] ?? warningDiagnostics[0] ?? null; const messageSource = errorDiagnostics.length > 0 ? errorDiagnostics : warningDiagnostics; results.set( entry.index, makeResult({ isValid: errorDiagnostics.length === 0, errorCount: errorDiagnostics.length, warningCount: warningDiagnostics.length, message: messageSource .slice(0, 3) .map((diag) => diag.message) .join(" | "), severity: top ? top.severity : null, code: top ? top.code : null, labels: top ? top.labels : [], codeframe: top ? 
top.codeframe : null, }), ); } return results; } catch (error) { return fallbackLintResults(entries, `oxlint execution failed: ${error}`); } finally { rmSync(tempDir, { recursive: true, force: true }); } } function readStdin() { return new Promise((resolve, reject) => { let data = ""; process.stdin.setEncoding("utf8"); process.stdin.on("data", (chunk) => { data += chunk; }); process.stdin.on("end", () => resolve(data)); process.stdin.on("error", (error) => reject(error)); }); } function runValidation({ codes, lang, mode, codeShape }) { if (mode === "syntax") { return codes.map((code, index) => validateSyntaxOne({ code, lang, index, codeShape }).result, ); } if (mode === "lint") { const entries = codes.map((code, index) => resolveLintEntry({ code, lang, index, codeShape }), ); const lintMap = runLintBatch(entries); return entries.map( (entry) => lintMap.get(entry.index) ?? makeResult({ isValid: true, errorCount: 0, warningCount: 0, }), ); } const syntaxRuns = codes.map((code, index) => validateSyntaxOne({ code, lang, index, codeShape }), ); const lintTargets = syntaxRuns .filter((run) => run.result.is_valid === true) .map((run) => run.lintEntry); const lintMap = runLintBatch(lintTargets); return syntaxRuns.map((run) => { if (run.result.is_valid !== true) { return run.result; } return ( lintMap.get(run.lintEntry.index) ?? makeResult({ isValid: true, errorCount: 0, warningCount: 0, }) ); }); } async function main() { const raw = await readStdin(); let payload; try { payload = JSON.parse(raw || "{}"); } catch { process.stdout.write( JSON.stringify([ makeResult({ isValid: false, errorCount: 1, warningCount: 0, message: "Invalid JSON payload", severity: "error", }), ]), ); return; } const lang = mapLang(payload?.lang); const mode = mapMode(payload?.mode); const codeShape = mapCodeShape(payload?.code_shape); const codes = Array.isArray(payload?.codes) ? 
payload.codes : []; const out = runValidation({ codes, lang, mode, codeShape }); process.stdout.write(JSON.stringify(out)); } main().catch((error) => { process.stderr.write(String(error?.stack || error)); process.exit(1); }); ================================================ FILE: studio/backend/core/data_recipe/service.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import base64 import io import os from pathlib import Path from typing import Any from .jsonable import to_jsonable from .local_callable_validators import ( register_oxc_local_callable_validators, split_oxc_local_callable_validators, ) _IMAGE_CONTEXT_PATCHED = False def _encode_bytes_to_base64(value: bytes | bytearray) -> str: return base64.b64encode(bytes(value)).decode("utf-8") def _load_image_file_to_base64( path_value: str, *, base_path: str | None = None ) -> str | None: try: path = Path(path_value) candidates: list[Path] = [] if path.is_absolute(): candidates.append(path) else: if base_path: candidates.append(Path(base_path) / path) candidates.append(Path.cwd() / path) for candidate in candidates: if not candidate.exists() or not candidate.is_file(): continue with candidate.open("rb") as f: return _encode_bytes_to_base64(f.read()) except (OSError, TypeError, ValueError): return None return None def _pil_image_to_base64(value: Any) -> str | None: try: from PIL.Image import Image as PILImage # type: ignore except ImportError: return None if not isinstance(value, PILImage): return None buffer = io.BytesIO() image_format = str(getattr(value, "format", "") or "").upper() if image_format not in {"PNG", "JPEG", "JPG", "WEBP", "GIF"}: image_format = "PNG" value.save(buffer, format = image_format) return _encode_bytes_to_base64(buffer.getvalue()) def _normalize_image_context_value(value: Any, *, base_path: str | None = None) -> 
Any: if isinstance(value, str): return value if isinstance(value, (bytes, bytearray)): return _encode_bytes_to_base64(value) pil_base64 = _pil_image_to_base64(value) if pil_base64 is not None: return pil_base64 if isinstance(value, dict): url = value.get("url") if isinstance(url, str): return url image_url = value.get("image_url") if isinstance(image_url, str): return image_url if isinstance(image_url, dict): nested_url = image_url.get("url") if isinstance(nested_url, str): return nested_url inline_data = value.get("data") if isinstance(inline_data, str): return inline_data raw_bytes = value.get("bytes") if isinstance(raw_bytes, (bytes, bytearray)): return _encode_bytes_to_base64(raw_bytes) if isinstance(raw_bytes, str) and raw_bytes.strip(): return raw_bytes path_value = value.get("path") if isinstance(path_value, str) and path_value.strip(): if as_base64 := _load_image_file_to_base64(path_value, base_path = base_path): return as_base64 return path_value return value def _apply_data_designer_image_context_patch() -> None: global _IMAGE_CONTEXT_PATCHED if _IMAGE_CONTEXT_PATCHED: return try: from data_designer.config.models import ImageContext except ImportError: return if getattr(ImageContext, "_unsloth_image_context_patch_applied", False): _IMAGE_CONTEXT_PATCHED = True return original_auto_resolve = ImageContext._auto_resolve_context_value def _patched_auto_resolve( self: Any, context_value: Any, base_path: str | None ) -> Any: normalized = _normalize_image_context_value(context_value, base_path = base_path) return original_auto_resolve(self, normalized, base_path) ImageContext._auto_resolve_context_value = _patched_auto_resolve setattr(ImageContext, "_unsloth_image_context_patch_applied", True) _IMAGE_CONTEXT_PATCHED = True def build_model_providers(recipe: dict[str, Any]): from data_designer.config.models import ModelProvider providers: list[ModelProvider] = [] for provider in recipe.get("model_providers", []): api_key = provider.get("api_key") api_key_env = 
provider.get("api_key_env") if not api_key and api_key_env: api_key = os.getenv(api_key_env) providers.append( ModelProvider( name = provider["name"], endpoint = provider["endpoint"], provider_type = provider.get("provider_type", "openai"), api_key = api_key, extra_headers = provider.get("extra_headers"), extra_body = provider.get("extra_body"), ) ) return providers def _recipe_has_llm_columns(recipe: dict[str, Any]) -> bool: for column in recipe.get("columns", []): if not isinstance(column, dict): continue column_type = column.get("column_type") if isinstance(column_type, str) and column_type.startswith("llm-"): return True return False def _validate_recipe_runtime_support( recipe: dict[str, Any], model_providers: list[Any], ) -> None: if not _recipe_has_llm_columns(recipe): raise ValueError( "Recipe Studio currently requires at least one AI generation step." ) if not model_providers: raise ValueError("Add a Provider connection block before running this recipe.") def build_mcp_providers( recipe: dict[str, Any], ) -> list: from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider providers: list[MCPProvider | LocalStdioMCPProvider] = [] for provider in recipe.get("mcp_providers", []): if not isinstance(provider, dict): continue provider_type = provider.get("provider_type") if provider_type == "stdio": env = provider.get("env") if not isinstance(env, dict): env = {} args = provider.get("args") if not isinstance(args, list): args = [] providers.append( LocalStdioMCPProvider( name = str(provider.get("name", "")), command = str(provider.get("command", "")), args = [str(value) for value in args], env = {str(key): str(value) for key, value in env.items()}, ) ) continue if provider_type in {"sse", "streamable_http"}: api_key = provider.get("api_key") api_key_env = provider.get("api_key_env") if not api_key and api_key_env: api_key = os.getenv(str(api_key_env)) providers.append( MCPProvider( name = str(provider.get("name", "")), endpoint = 
str(provider.get("endpoint", "")), provider_type = str(provider_type), api_key = str(api_key) if api_key else None, ) ) return providers def build_config_builder(recipe: dict[str, Any]): _apply_data_designer_image_context_patch() from data_designer.config import DataDesignerConfigBuilder from data_designer.config.processors import ProcessorType recipe_core = { key: value for key, value in recipe.items() if key not in {"model_providers", "mcp_providers"} } recipe_core, oxc_local_callable_specs = split_oxc_local_callable_validators( recipe_core ) builder = DataDesignerConfigBuilder.from_config({"data_designer": recipe_core}) register_oxc_local_callable_validators( builder = builder, specs = oxc_local_callable_specs, ) # DataDesignerConfigBuilder.from_config currently skips processors. # Re-attach explicitly so drop_columns/schema_transform survive API payload. for processor in recipe_core.get("processors") or []: if not isinstance(processor, dict): continue processor_type_raw = processor.get("processor_type") if not isinstance(processor_type_raw, str): continue kwargs = {k: v for k, v in processor.items() if k != "processor_type"} builder.add_processor( processor_type = ProcessorType(processor_type_raw), **kwargs, ) return builder def create_data_designer( recipe: dict[str, Any], *, artifact_path: str | None = None, ): _apply_data_designer_image_context_patch() from data_designer.interface.data_designer import DataDesigner model_providers = build_model_providers(recipe) _validate_recipe_runtime_support(recipe, model_providers) return DataDesigner( artifact_path = artifact_path, model_providers = model_providers, mcp_providers = build_mcp_providers(recipe), ) def validate_recipe(recipe: dict[str, Any]) -> None: builder = build_config_builder(recipe) designer = create_data_designer(recipe) designer.validate(builder) def preview_recipe( recipe: dict[str, Any], num_records: int, ) -> tuple[list[dict[str, Any]], dict[str, Any] | None, dict[str, Any] | None]: builder = 
build_config_builder(recipe) designer = create_data_designer(recipe) results = designer.preview(builder, num_records = num_records) dataset: list[dict[str, Any]] = [] if results.dataset is not None: raw_rows = results.dataset.to_dict(orient = "records") dataset = [to_jsonable(row) for row in raw_rows] artifacts = ( None if results.processor_artifacts is None else to_jsonable(results.processor_artifacts) ) analysis = ( None if results.analysis is None else to_jsonable(results.analysis.model_dump(mode = "json")) ) return dataset, artifacts, analysis ================================================ FILE: studio/backend/core/export/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Export submodule - Model export operations The default get_export_backend() returns an ExportOrchestrator that delegates to a subprocess. The original ExportBackend runs inside the subprocess and can be imported directly from .export when needed. """ from .orchestrator import ExportOrchestrator, get_export_backend # Expose ExportOrchestrator as ExportBackend for backward compat ExportBackend = ExportOrchestrator __all__ = [ "ExportBackend", "ExportOrchestrator", "get_export_backend", ] ================================================ FILE: studio/backend/core/export/export.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 # backend/export.py """ Export backend - handles model exporting in various formats """ import glob import json import structlog from loggers import get_logger import os import shutil from pathlib import Path from typing import Optional, Tuple, List from peft import PeftModel, PeftModelForCausalLM from unsloth import FastLanguageModel, FastVisionModel from huggingface_hub import HfApi, ModelCard from transformers.modeling_utils import PushToHubMixin import torch from utils.hardware import clear_gpu_cache from utils.models import is_vision_model, get_base_model_from_lora from utils.models.model_config import detect_audio_type from utils.paths import ensure_dir, outputs_root, resolve_export_dir, resolve_output_dir from core.inference import get_inference_backend logger = get_logger(__name__) def _is_wsl(): """Detect if running under Windows Subsystem for Linux.""" try: return "microsoft" in open("/proc/version").read().lower() except Exception: return False def _apply_wsl_sudo_patch(): """On WSL, monkey-patch do_we_need_sudo() to return False. WSL doesn't have passwordless sudo, and do_we_need_sudo() runs `sudo apt-get update` which hangs waiting for a stdin password inside a non-interactive subprocess. setup.sh pre-installs the build dependencies on WSL, so sudo is not needed at runtime. 
""" if not _is_wsl(): return try: import unsloth_zoo.llama_cpp as llama_cpp_module def _wsl_do_we_need_sudo(system_type = "debian"): logger.info( "WSL detected — skipping sudo check " "(build deps pre-installed by setup.sh)" ) return False llama_cpp_module.do_we_need_sudo = _wsl_do_we_need_sudo logger.info( "Applied WSL sudo patch to " "unsloth_zoo.llama_cpp.do_we_need_sudo" ) except Exception as e: logger.warning(f"Could not apply WSL sudo patch: {e}") # Model card template MODEL_CARD = """--- base_model: {base_model} tags: - text-generation-inference - transformers - unsloth - {model_type} - {extra} license: apache-2.0 language: - en --- # Uploaded finetuned {method} model - **Developed by:** {username} - **License:** apache-2.0 - **Finetuned from model :** {base_model} This {model_type} model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library. [](https://github.com/unslothai/unsloth) """ class ExportBackend: """Handles model export operations""" def __init__(self): self.inference_backend = get_inference_backend() self.current_checkpoint = None self.current_model = None self.current_tokenizer = None self.is_vision = False self.is_peft = False self._audio_type = None def cleanup_memory(self): """Offload and delete all models from memory""" try: logger.info("Starting memory cleanup...") # Unload all models from inference backend model_names = list(self.inference_backend.models.keys()) for model_name in model_names: self.inference_backend.unload_model(model_name) # Clear current export state self.current_model = None self.current_tokenizer = None self.current_checkpoint = None self._audio_type = None # Clear GPU memory cache (handles gc + backend-specific cleanup) clear_gpu_cache() logger.info("Memory cleanup completed successfully") return True except Exception as e: logger.error(f"Error during memory cleanup: {e}") return False def scan_checkpoints( self, outputs_dir: str = str(outputs_root()) ) -> 
List[Tuple[str, List[Tuple[str, str]]]]: """ Scan outputs folder for training runs and their checkpoints. Returns: List of tuples: [(model_name, [(display_name, checkpoint_path), ...]), ...] """ from utils.models.checkpoints import scan_checkpoints return scan_checkpoints(outputs_dir = outputs_dir) def load_checkpoint( self, checkpoint_path: str, max_seq_length: int = 2048, load_in_4bit: bool = True, trust_remote_code: bool = False, ) -> Tuple[bool, str]: """ Load a checkpoint for export. Returns: Tuple of (success: bool, message: str) """ try: logger.info(f"Loading checkpoint: {checkpoint_path}") # First, cleanup existing models self.cleanup_memory() checkpoint_path_obj = Path(checkpoint_path) # Determine the model identity for type detection adapter_config = checkpoint_path_obj / "adapter_config.json" base_model = None if adapter_config.exists(): base_model = get_base_model_from_lora(checkpoint_path) if not base_model: return False, "Could not determine base model for adapter" model_id = base_model or checkpoint_path # Detect audio type and vision self._audio_type = detect_audio_type(model_id) self.is_vision = not self._audio_type and is_vision_model(model_id) # Load model based on type if self._audio_type == "csm": from unsloth import FastModel from transformers import CsmForConditionalGeneration logger.info("Loading as CSM audio model...") model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, dtype = None, auto_model = CsmForConditionalGeneration, load_in_4bit = False, trust_remote_code = trust_remote_code, ) elif self._audio_type == "whisper": from unsloth import FastModel from transformers import WhisperForConditionalGeneration logger.info("Loading as Whisper audio model...") model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, dtype = None, load_in_4bit = False, auto_model = WhisperForConditionalGeneration, trust_remote_code = trust_remote_code, ) elif self._audio_type == "snac": 
logger.info("Loading as SNAC (Orpheus) audio model...") model, tokenizer = FastLanguageModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, dtype = None, load_in_4bit = load_in_4bit, trust_remote_code = trust_remote_code, ) elif self._audio_type == "bicodec": from unsloth import FastModel logger.info("Loading as BiCodec (Spark-TTS) audio model...") model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, dtype = torch.float32, load_in_4bit = False, trust_remote_code = trust_remote_code, ) elif self._audio_type == "dac": from unsloth import FastModel logger.info("Loading as DAC (OuteTTS) audio model...") model, tokenizer = FastModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, load_in_4bit = False, trust_remote_code = trust_remote_code, ) elif self.is_vision: logger.info("Loading as vision model...") model, processor = FastVisionModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, dtype = None, load_in_4bit = load_in_4bit, trust_remote_code = trust_remote_code, ) tokenizer = processor # For vision models, processor acts as tokenizer else: logger.info("Loading as text model...") model, tokenizer = FastLanguageModel.from_pretrained( model_name = checkpoint_path, max_seq_length = max_seq_length, dtype = None, load_in_4bit = load_in_4bit, trust_remote_code = trust_remote_code, ) # Check if PEFT model self.is_peft = isinstance(model, (PeftModel, PeftModelForCausalLM)) # Store loaded model self.current_model = model self.current_tokenizer = tokenizer self.current_checkpoint = checkpoint_path if self._audio_type: model_type = f"Audio ({self._audio_type})" elif self.is_vision: model_type = "Vision" else: model_type = "Text" peft_info = " (PEFT Adapter)" if self.is_peft else " (Merged Model)" logger.info(f"Successfully loaded {model_type} model{peft_info}") return True, f"Loaded {model_type} model{peft_info} successfully" 
except Exception as e: logger.error(f"Error loading checkpoint: {e}") import traceback logger.error(traceback.format_exc()) return False, f"Failed to load checkpoint: {str(e)}" def _write_export_metadata(self, save_directory: str): """Write export_metadata.json with base model info for Chat page discovery.""" try: base_model = ( get_base_model_from_lora(self.current_checkpoint) if self.current_checkpoint else None ) metadata = {"base_model": base_model} metadata_path = os.path.join(save_directory, "export_metadata.json") with open(metadata_path, "w") as f: json.dump(metadata, f, indent = 2) logger.info(f"Wrote export metadata to {metadata_path}") except Exception as e: logger.warning(f"Could not write export metadata: {e}") def export_merged_model( self, save_directory: str, format_type: str = "16-bit (FP16)", push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, ) -> Tuple[bool, str]: """ Export merged model (for PEFT models). Args: save_directory: Local directory to save model format_type: "16-bit (FP16)" or "4-bit (FP4)" push_to_hub: Whether to push to Hugging Face Hub repo_id: Hub repository ID (username/model-name) hf_token: Hugging Face token private: Whether to make the repo private Returns: Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first." if not self.is_peft: return False, "This is not a PEFT model. Use 'Export Base Model' instead." 
try: # Determine save method if format_type == "4-bit (FP4)": save_method = "merged_4bit_forced" elif self._audio_type == "whisper": # Whisper uses save_method=None for local 16-bit merged save save_method = None else: # 16-bit (FP16) save_method = "merged_16bit" # Save locally if requested if save_directory: save_directory = str(resolve_export_dir(save_directory)) logger.info(f"Saving merged model locally to: {save_directory}") ensure_dir(Path(save_directory)) self.current_model.save_pretrained_merged( save_directory, self.current_tokenizer, save_method = save_method ) # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") # Push to hub if requested if push_to_hub: if not repo_id or not hf_token: return ( False, "Repository ID and Hugging Face token required for Hub upload", ) logger.info(f"Pushing merged model to Hub: {repo_id}") # Whisper uses save_method=None for local but "merged_16bit" for hub push hub_save_method = ( save_method if save_method is not None else "merged_16bit" ) self.current_model.push_to_hub_merged( repo_id, self.current_tokenizer, save_method = hub_save_method, token = hf_token, private = private, ) logger.info(f"Model pushed successfully to {repo_id}") return True, "Model exported successfully" except Exception as e: logger.error(f"Error exporting merged model: {e}") import traceback logger.error(traceback.format_exc()) return False, f"Export failed: {str(e)}" def export_base_model( self, save_directory: str, push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, base_model_id: Optional[str] = None, ) -> Tuple[bool, str]: """ Export base model (for non-PEFT models). Returns: Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first." 
if self.is_peft: return ( False, "This is a PEFT model. Use 'Merged Model' export type instead.", ) try: # Save locally if requested if save_directory: save_directory = str(resolve_export_dir(save_directory)) logger.info(f"Saving base model locally to: {save_directory}") ensure_dir(Path(save_directory)) self.current_model.save_pretrained(save_directory) self.current_tokenizer.save_pretrained(save_directory) # Write export metadata so the Chat page can identify the base model self._write_export_metadata(save_directory) logger.info(f"Model saved successfully to {save_directory}") # Push to hub if requested if push_to_hub: if not repo_id or not hf_token: return ( False, "Repository ID and Hugging Face token required for Hub upload", ) logger.info(f"Pushing base model to Hub: {repo_id}") # Get base model name from request or model config base_model = ( base_model_id or self.current_model.config._name_or_path or "unknown" ) # Create repo hf_api = HfApi(token = hf_token) repo_id = PushToHubMixin._create_repo( PushToHubMixin, repo_id = repo_id, private = private, token = hf_token, ) username = repo_id.split("/")[0] # Create and push model card content = MODEL_CARD.format( username = username, base_model = base_model, model_type = self.current_model.config.model_type, method = "", extra = "unsloth", ) card = ModelCard(content) card.push_to_hub( repo_id, token = hf_token, commit_message = "Unsloth Model Card" ) # Upload model files if save_directory: hf_api.upload_folder( folder_path = save_directory, repo_id = repo_id, repo_type = "model" ) logger.info(f"Model pushed successfully to {repo_id}") else: return False, "Local save directory required for Hub upload" return True, "Model exported successfully" except Exception as e: logger.error(f"Error exporting base model: {e}") import traceback logger.error(traceback.format_exc()) return False, f"Export failed: {str(e)}" def export_gguf( self, save_directory: str, quantization_method: str = "Q4_K_M", push_to_hub: bool = False, 
repo_id: Optional[str] = None, hf_token: Optional[str] = None, ) -> Tuple[bool, str]: """ Export model in GGUF format. Args: save_directory: Local directory to save model quantization_method: GGUF quantization method (e.g., "Q4_K_M") push_to_hub: Whether to push to Hugging Face Hub repo_id: Hub repository ID hf_token: Hugging Face token Returns: Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first." try: # Convert quantization method to lowercase for unsloth quant_method = quantization_method.lower() # Save locally if requested if save_directory: save_directory = str(resolve_export_dir(save_directory)) # Resolve to absolute path so unsloth's relative-path internals # (check_llama_cpp, use_local_gguf, _download_convert_hf_to_gguf) # all resolve against the repo root cwd, NOT the export directory. abs_save_dir = os.path.abspath(save_directory) logger.info(f"Saving GGUF model locally to: {abs_save_dir}") # Create the directory if it doesn't exist ensure_dir(Path(abs_save_dir)) # On WSL, patch out sudo check before llama.cpp build _apply_wsl_sudo_patch() # Snapshot existing .gguf files in cwd before conversion. # unsloth's convert_to_gguf writes output files relative to # cwd (repo root), so we diff afterwards and relocate them. cwd = os.getcwd() pre_existing_ggufs = set(glob.glob(os.path.join(cwd, "*.gguf"))) # Pass absolute path — no os.chdir needed. # unsloth saves intermediate HF model files into model_save_path. # unsloth-zoo's check_llama_cpp() uses ~/.unsloth/llama.cpp by default. model_save_path = os.path.join(abs_save_dir, "model") self.current_model.save_pretrained_gguf( model_save_path, self.current_tokenizer, quantization_method = quant_method, ) # Relocate GGUF artifacts into the export directory. # convert_to_gguf writes .gguf files to cwd (repo root) # because --outfile is a relative path like "model.Q4_K_M.gguf". 
new_ggufs = ( set(glob.glob(os.path.join(cwd, "*.gguf"))) - pre_existing_ggufs ) for src in sorted(new_ggufs): dest = os.path.join(abs_save_dir, os.path.basename(src)) shutil.move(src, dest) logger.info( f"Relocated GGUF: {os.path.basename(src)} → {abs_save_dir}/" ) # Flatten any .gguf files from subdirectories into abs_save_dir. # save_pretrained_gguf may create subdirs (e.g. model_gguf/) # with a name different from model_save_path. for sub in list(Path(abs_save_dir).iterdir()): if not sub.is_dir(): continue for src in sub.glob("*.gguf"): dest = os.path.join(abs_save_dir, src.name) shutil.move(str(src), dest) logger.info(f"Relocated GGUF: {src.name} → {abs_save_dir}/") # Clean up the subdirectory (intermediate HF files, etc.) shutil.rmtree(str(sub), ignore_errors = True) logger.info(f"Cleaned up subdirectory: {sub.name}") # Write export metadata so the Chat page can identify the base model self._write_export_metadata(abs_save_dir) # Log final file locations (after relocation) so it's clear # where the GGUF files actually ended up. final_ggufs = sorted(glob.glob(os.path.join(abs_save_dir, "*.gguf"))) logger.info( "GGUF export complete. 
Final files in %s:\n %s", abs_save_dir, "\n ".join(os.path.basename(f) for f in final_ggufs) or "(none)", ) # Push to hub if requested if push_to_hub: if not repo_id or not hf_token: return ( False, "Repository ID and Hugging Face token required for Hub upload", ) logger.info(f"Pushing GGUF model to Hub: {repo_id}") self.current_model.push_to_hub_gguf( repo_id, self.current_tokenizer, quantization_method = quant_method, token = hf_token, ) logger.info(f"GGUF model pushed successfully to {repo_id}") return True, f"GGUF model exported successfully ({quantization_method})" except Exception as e: logger.error(f"Error exporting GGUF model: {e}") import traceback logger.error(traceback.format_exc()) return False, f"GGUF export failed: {str(e)}" def export_lora_adapter( self, save_directory: str, push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, ) -> Tuple[bool, str]: """ Export LoRA adapter only (not merged). Returns: Tuple of (success: bool, message: str) """ if not self.current_model or not self.current_tokenizer: return False, "No model loaded. Please select a checkpoint first." if not self.is_peft: return False, "This is not a PEFT model. No adapter to export." 
try: # Save locally if requested if save_directory: save_directory = str(resolve_export_dir(save_directory)) logger.info(f"Saving LoRA adapter locally to: {save_directory}") ensure_dir(Path(save_directory)) self.current_model.save_pretrained(save_directory) self.current_tokenizer.save_pretrained(save_directory) logger.info(f"Adapter saved successfully to {save_directory}") # Push to hub if requested if push_to_hub: if not repo_id or not hf_token: return ( False, "Repository ID and Hugging Face token required for Hub upload", ) logger.info(f"Pushing LoRA adapter to Hub: {repo_id}") self.current_model.push_to_hub(repo_id, token = hf_token, private = private) self.current_tokenizer.push_to_hub( repo_id, token = hf_token, private = private ) logger.info(f"Adapter pushed successfully to {repo_id}") return True, "LoRA adapter exported successfully" except Exception as e: logger.error(f"Error exporting LoRA adapter: {e}") import traceback logger.error(traceback.format_exc()) return False, f"Adapter export failed: {str(e)}" # Global export backend instance _export_backend = None def get_export_backend() -> ExportBackend: """Get or create the global export backend instance""" global _export_backend if _export_backend is None: _export_backend = ExportBackend() return _export_backend ================================================ FILE: studio/backend/core/export/orchestrator.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Export orchestrator — subprocess-based. Provides the same API as ExportBackend, but delegates all ML work to a persistent subprocess. The subprocess is spawned on first checkpoint load and stays alive for subsequent export operations. When switching between checkpoints that need different transformers versions, the old subprocess is killed and a new one is spawned with the correct version. 
Pattern follows core/inference/orchestrator.py. """ import atexit import structlog from loggers import get_logger import multiprocessing as mp import queue import threading import time from pathlib import Path from typing import Any, List, Optional, Tuple from utils.paths import outputs_root logger = get_logger(__name__) _CTX = mp.get_context("spawn") class ExportOrchestrator: """ Export backend orchestrator — subprocess-based. Exposes the same API surface as ExportBackend so routes/export.py needs minimal changes. Internally, all heavy ML operations happen in a persistent subprocess. """ def __init__(self): # Subprocess state self._proc: Optional[mp.Process] = None self._cmd_queue: Any = None self._resp_queue: Any = None self._lock = threading.Lock() # Local state mirrors (updated from subprocess responses) self.current_checkpoint: Optional[str] = None self.is_vision: bool = False self.is_peft: bool = False atexit.register(self._cleanup) logger.info("ExportOrchestrator initialized (subprocess mode)") # ------------------------------------------------------------------ # Subprocess lifecycle # ------------------------------------------------------------------ def _spawn_subprocess(self, config: dict) -> None: """Spawn a new export subprocess.""" from .worker import run_export_process self._cmd_queue = _CTX.Queue() self._resp_queue = _CTX.Queue() self._proc = _CTX.Process( target = run_export_process, kwargs = { "cmd_queue": self._cmd_queue, "resp_queue": self._resp_queue, "config": config, }, daemon = True, ) self._proc.start() logger.info("Export subprocess started (pid=%s)", self._proc.pid) def _shutdown_subprocess(self, timeout: float = 10.0) -> None: """Gracefully shut down the export subprocess.""" if self._proc is None or not self._proc.is_alive(): self._proc = None return # 1. Drain stale responses self._drain_queue() # 2. Send shutdown command try: self._cmd_queue.put({"type": "shutdown"}) except (OSError, ValueError): pass # 3. 
Wait for graceful shutdown try: self._proc.join(timeout = timeout) except Exception: pass # 4. Force kill if still alive if self._proc is not None and self._proc.is_alive(): logger.warning("Export subprocess did not exit gracefully, terminating") try: self._proc.terminate() self._proc.join(timeout = 5) except Exception: pass if self._proc is not None and self._proc.is_alive(): logger.warning("Subprocess still alive after terminate, killing") try: self._proc.kill() self._proc.join(timeout = 3) except Exception: pass self._proc = None self._cmd_queue = None self._resp_queue = None logger.info("Export subprocess shut down") def _cleanup(self): """atexit handler.""" self._shutdown_subprocess(timeout = 5.0) def _ensure_subprocess_alive(self) -> bool: """Check if subprocess is alive.""" return self._proc is not None and self._proc.is_alive() # ------------------------------------------------------------------ # Queue helpers # ------------------------------------------------------------------ def _send_cmd(self, cmd: dict) -> None: """Send a command to the subprocess.""" if self._cmd_queue is None: raise RuntimeError("No export subprocess running") try: self._cmd_queue.put(cmd) except (OSError, ValueError) as exc: raise RuntimeError(f"Failed to send command to subprocess: {exc}") def _read_resp(self, timeout: float = 1.0) -> Optional[dict]: """Read a response from the subprocess (non-blocking with timeout).""" if self._resp_queue is None: return None try: return self._resp_queue.get(timeout = timeout) except queue.Empty: return None except (EOFError, OSError, ValueError): return None def _wait_response(self, expected_type: str, timeout: float = 3600.0) -> dict: """Block until a response of the expected type arrives. Export operations can take a very long time — GGUF conversion for large models (30B+) easily takes 20-30 minutes. Default timeout is 1 hour. 
""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: remaining = max(0.1, deadline - time.monotonic()) resp = self._read_resp(timeout = min(remaining, 2.0)) if resp is None: # Check subprocess health if not self._ensure_subprocess_alive(): raise RuntimeError("Export subprocess crashed during wait") continue rtype = resp.get("type", "") if rtype == expected_type: return resp if rtype == "error": error_msg = resp.get("error", "Unknown error") raise RuntimeError(f"Subprocess error: {error_msg}") if rtype == "status": logger.info("Export subprocess status: %s", resp.get("message", "")) continue # Other response types during wait — skip logger.debug( "Skipping response type '%s' while waiting for '%s'", rtype, expected_type, ) raise RuntimeError( f"Timeout waiting for '{expected_type}' response after {timeout}s" ) def _drain_queue(self) -> list: """Drain all pending responses.""" events = [] if self._resp_queue is None: return events while True: try: events.append(self._resp_queue.get_nowait()) except queue.Empty: return events except (EOFError, OSError, ValueError): return events # ------------------------------------------------------------------ # Public API — same interface as ExportBackend # ------------------------------------------------------------------ def load_checkpoint( self, checkpoint_path: str, max_seq_length: int = 2048, load_in_4bit: bool = True, trust_remote_code: bool = False, ) -> Tuple[bool, str]: """Load a checkpoint for export. Always spawns a fresh subprocess to ensure a clean Python interpreter. """ sub_config = { "checkpoint_path": checkpoint_path, "max_seq_length": max_seq_length, "load_in_4bit": load_in_4bit, "trust_remote_code": trust_remote_code, } # Always kill existing subprocess and spawn fresh. 
if self._ensure_subprocess_alive(): self._shutdown_subprocess() elif self._proc is not None: self._shutdown_subprocess(timeout = 2) logger.info("Spawning fresh export subprocess for '%s'", checkpoint_path) self._spawn_subprocess(sub_config) try: resp = self._wait_response("loaded", timeout = 300) except RuntimeError as exc: self._shutdown_subprocess(timeout = 5) self.current_checkpoint = None self.is_vision = False self.is_peft = False return False, str(exc) if resp.get("success"): self.current_checkpoint = resp.get("checkpoint") self.is_vision = resp.get("is_vision", False) self.is_peft = resp.get("is_peft", False) logger.info("Checkpoint '%s' loaded in subprocess", checkpoint_path) return True, resp.get("message", "Loaded successfully") else: error = resp.get("message", "Failed to load checkpoint") logger.error("Failed to load checkpoint: %s", error) self.current_checkpoint = None self.is_vision = False self.is_peft = False return False, error def export_merged_model( self, save_directory: str, format_type: str = "16-bit (FP16)", push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, ) -> Tuple[bool, str]: """Export merged PEFT model.""" return self._run_export( "merged", { "save_directory": save_directory, "format_type": format_type, "push_to_hub": push_to_hub, "repo_id": repo_id, "hf_token": hf_token, "private": private, }, ) def export_base_model( self, save_directory: str, push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, base_model_id: Optional[str] = None, ) -> Tuple[bool, str]: """Export base model (non-PEFT).""" return self._run_export( "base", { "save_directory": save_directory, "push_to_hub": push_to_hub, "repo_id": repo_id, "hf_token": hf_token, "private": private, "base_model_id": base_model_id, }, ) def export_gguf( self, save_directory: str, quantization_method: str = "Q4_K_M", push_to_hub: bool = False, repo_id: Optional[str] = 
None, hf_token: Optional[str] = None, ) -> Tuple[bool, str]: """Export model in GGUF format.""" return self._run_export( "gguf", { "save_directory": save_directory, "quantization_method": quantization_method, "push_to_hub": push_to_hub, "repo_id": repo_id, "hf_token": hf_token, }, ) def export_lora_adapter( self, save_directory: str, push_to_hub: bool = False, repo_id: Optional[str] = None, hf_token: Optional[str] = None, private: bool = False, ) -> Tuple[bool, str]: """Export LoRA adapter only.""" return self._run_export( "lora", { "save_directory": save_directory, "push_to_hub": push_to_hub, "repo_id": repo_id, "hf_token": hf_token, "private": private, }, ) def _run_export(self, export_type: str, params: dict) -> Tuple[bool, str]: """Send an export command to the subprocess and wait for result.""" if not self._ensure_subprocess_alive(): return False, "No export subprocess running. Load a checkpoint first." cmd = {"type": "export", "export_type": export_type, **params} try: self._send_cmd(cmd) resp = self._wait_response( f"export_{export_type}_done", timeout = 3600, # GGUF for 30B+ models can take 30+ min ) return resp.get("success", False), resp.get("message", "") except RuntimeError as exc: return False, str(exc) def cleanup_memory(self) -> bool: """Cleanup export-related models from memory.""" if not self._ensure_subprocess_alive(): # No subprocess — just clear local state self.current_checkpoint = None self.is_vision = False self.is_peft = False return True try: self._send_cmd({"type": "cleanup"}) resp = self._wait_response("cleanup_done", timeout = 30) success = resp.get("success", False) except RuntimeError: success = False # Shut down subprocess after cleanup — no model loaded self._shutdown_subprocess() self.current_checkpoint = None self.is_vision = False self.is_peft = False return success def scan_checkpoints( self, outputs_dir: str = str(outputs_root()) ) -> List[Tuple[str, list]]: """Scan for checkpoints — no ML imports needed, runs locally.""" from 
utils.models.checkpoints import scan_checkpoints return scan_checkpoints(outputs_dir = outputs_dir) # ========== GLOBAL INSTANCE ========== _export_backend = None def get_export_backend() -> ExportOrchestrator: """Get global export backend instance (orchestrator).""" global _export_backend if _export_backend is None: _export_backend = ExportOrchestrator() return _export_backend ================================================ FILE: studio/backend/core/export/worker.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Export subprocess entry point. Each export session runs in a persistent subprocess (mp.get_context("spawn")). This gives us a clean Python interpreter with no stale module state — solving the transformers version-switching problem completely. The subprocess stays alive while a model is loaded, accepting commands (load, export_merged, export_base, export_gguf, export_lora, cleanup, shutdown) via mp.Queue. Pattern follows core/inference/worker.py and core/training/worker.py. """ from __future__ import annotations import structlog from loggers import get_logger import os import sys import time import traceback from pathlib import Path from typing import Any logger = get_logger(__name__) def _activate_transformers_version(model_name: str) -> None: """Activate the correct transformers version BEFORE any ML imports. If the model needs transformers 5.x, prepend the pre-installed .venv_t5/ directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/). 
""" # Ensure backend is on path for utils imports backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from utils.transformers_version import ( needs_transformers_5, _resolve_base_model, _ensure_venv_t5_exists, _VENV_T5_DIR, ) resolved = _resolve_base_model(model_name) if needs_transformers_5(resolved): if not _ensure_venv_t5_exists(): raise RuntimeError( f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}" ) if _VENV_T5_DIR not in sys.path: sys.path.insert(0, _VENV_T5_DIR) logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR) # Propagate to child subprocesses (e.g. GGUF converter) _pp = os.environ.get("PYTHONPATH", "") os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "") else: logger.info("Using default transformers (4.57.x) for %s", model_name) def _send_response(resp_queue: Any, response: dict) -> None: """Send a response to the parent process.""" try: resp_queue.put(response) except (OSError, ValueError) as exc: logger.error("Failed to send response: %s", exc) def _handle_load(backend, cmd: dict, resp_queue: Any) -> None: """Handle a load_checkpoint command.""" checkpoint_path = cmd["checkpoint_path"] max_seq_length = cmd.get("max_seq_length", 2048) load_in_4bit = cmd.get("load_in_4bit", True) trust_remote_code = cmd.get("trust_remote_code", False) try: _send_response( resp_queue, { "type": "status", "message": f"Loading checkpoint: {checkpoint_path}", "ts": time.time(), }, ) success, message = backend.load_checkpoint( checkpoint_path = checkpoint_path, max_seq_length = max_seq_length, load_in_4bit = load_in_4bit, trust_remote_code = trust_remote_code, ) _send_response( resp_queue, { "type": "loaded", "success": success, "message": message, "checkpoint": checkpoint_path if success else None, "is_vision": backend.is_vision if success else False, "is_peft": backend.is_peft if success else False, "ts": time.time(), }, ) except 
Exception as exc: _send_response( resp_queue, { "type": "loaded", "success": False, "message": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def _handle_export(backend, cmd: dict, resp_queue: Any) -> None: """Handle any export command (merged, base, gguf, lora).""" export_type = cmd["export_type"] # "merged", "base", "gguf", "lora" response_type = f"export_{export_type}_done" try: if export_type == "merged": success, message = backend.export_merged_model( save_directory = cmd.get("save_directory", ""), format_type = cmd.get("format_type", "16-bit (FP16)"), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), hf_token = cmd.get("hf_token"), private = cmd.get("private", False), ) elif export_type == "base": success, message = backend.export_base_model( save_directory = cmd.get("save_directory", ""), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), hf_token = cmd.get("hf_token"), private = cmd.get("private", False), base_model_id = cmd.get("base_model_id"), ) elif export_type == "gguf": success, message = backend.export_gguf( save_directory = cmd.get("save_directory", ""), quantization_method = cmd.get("quantization_method", "Q4_K_M"), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), hf_token = cmd.get("hf_token"), ) elif export_type == "lora": success, message = backend.export_lora_adapter( save_directory = cmd.get("save_directory", ""), push_to_hub = cmd.get("push_to_hub", False), repo_id = cmd.get("repo_id"), hf_token = cmd.get("hf_token"), private = cmd.get("private", False), ) else: success, message = False, f"Unknown export type: {export_type}" _send_response( resp_queue, { "type": response_type, "success": success, "message": message, "ts": time.time(), }, ) except Exception as exc: _send_response( resp_queue, { "type": response_type, "success": False, "message": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def 
_handle_cleanup(backend, resp_queue: Any) -> None: """Handle a cleanup command.""" try: success = backend.cleanup_memory() _send_response( resp_queue, { "type": "cleanup_done", "success": success, "ts": time.time(), }, ) except Exception as exc: _send_response( resp_queue, { "type": "cleanup_done", "success": False, "message": str(exc), "ts": time.time(), }, ) def run_export_process( *, cmd_queue: Any, resp_queue: Any, config: dict, ) -> None: """Subprocess entrypoint. Persistent — runs command loop until shutdown. Args: cmd_queue: mp.Queue for receiving commands from parent. resp_queue: mp.Queue for sending responses to parent. config: Initial configuration dict with checkpoint_path. """ import queue as _queue os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTHONWARNINGS"] = ( "ignore" # Suppress warnings at C-level before imports ) import warnings from loggers.config import LogConfig if os.getenv("ENVIRONMENT_TYPE", "production") == "production": warnings.filterwarnings("ignore") LogConfig.setup_logging( service_name = "unsloth-studio-export-worker", env = os.getenv("ENVIRONMENT_TYPE", "production"), ) checkpoint_path = config["checkpoint_path"] # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(checkpoint_path) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to activate transformers version: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 1b. On Windows, check Triton availability (must be before import torch) ── if sys.platform == "win32": try: import triton # noqa: F401 logger.info("Triton available — torch.compile enabled") except ImportError: os.environ["TORCHDYNAMO_DISABLE"] = "1" logger.warning( "Triton not found on Windows — torch.compile disabled. " 'Install for better performance: pip install "triton-windows<3.7"' ) # ── 2. 
Import ML libraries (fresh in this clean process) ── try: _send_response( resp_queue, { "type": "status", "message": "Importing Unsloth...", "ts": time.time(), }, ) backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from core.export.export import ExportBackend import transformers logger.info( "Export subprocess loaded transformers %s", transformers.__version__ ) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to import ML libraries: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 3. Create export backend and load initial checkpoint ── try: backend = ExportBackend() _handle_load(backend, config, resp_queue) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to initialize export backend: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 4. Command loop — process commands until shutdown ── logger.info("Export subprocess ready, entering command loop") while True: try: cmd = cmd_queue.get(timeout = 1.0) except _queue.Empty: continue except (EOFError, OSError): logger.info("Command queue closed, shutting down") return if cmd is None: continue cmd_type = cmd.get("type", "") logger.info("Received command: %s", cmd_type) try: if cmd_type == "load": # Load a new checkpoint (reusing this subprocess) backend.cleanup_memory() _handle_load(backend, cmd, resp_queue) elif cmd_type == "export": _handle_export(backend, cmd, resp_queue) elif cmd_type == "cleanup": _handle_cleanup(backend, resp_queue) elif cmd_type == "status": _send_response( resp_queue, { "type": "status_response", "checkpoint": backend.current_checkpoint, "is_vision": backend.is_vision, "is_peft": backend.is_peft, "ts": time.time(), }, ) elif cmd_type == "shutdown": logger.info("Shutdown command received, cleaning up and exiting") try: backend.cleanup_memory() except Exception: 
pass _send_response( resp_queue, { "type": "shutdown_ack", "ts": time.time(), }, ) return else: logger.warning("Unknown command type: %s", cmd_type) _send_response( resp_queue, { "type": "error", "error": f"Unknown command type: {cmd_type}", "ts": time.time(), }, ) except Exception as exc: logger.error( "Error handling command '%s': %s", cmd_type, exc, exc_info = True ) _send_response( resp_queue, { "type": "error", "error": f"Command '{cmd_type}' failed: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) ================================================ FILE: studio/backend/core/inference/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference submodule - Inference backend for model loading and generation The default get_inference_backend() returns an InferenceOrchestrator that delegates to a subprocess. The original InferenceBackend runs inside the subprocess and can be imported directly from .inference when needed. """ from .orchestrator import InferenceOrchestrator, get_inference_backend from .llama_cpp import LlamaCppBackend # Expose InferenceOrchestrator as InferenceBackend for backward compat InferenceBackend = InferenceOrchestrator __all__ = [ "InferenceBackend", "InferenceOrchestrator", "get_inference_backend", "LlamaCppBackend", ] ================================================ FILE: studio/backend/core/inference/audio_codecs.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Audio codec loading and decoding for TTS inference. 
Supports: SNAC (Orpheus), CSM (Sesame), BiCodec (Spark), DAC (OuteTTS) """ import io import re import wave import structlog from loggers import get_logger from typing import Optional, Tuple import numpy as np import torch logger = get_logger(__name__) def _numpy_to_wav_bytes(waveform: np.ndarray, sample_rate: int) -> bytes: """Convert a float32 numpy waveform to WAV bytes (16-bit PCM).""" waveform = waveform.flatten() peak = max(abs(waveform.max()), abs(waveform.min())) if peak > 1.0: waveform = waveform / peak pcm = (waveform * 32767).astype(np.int16) buf = io.BytesIO() with wave.open(buf, "wb") as wf: wf.setnchannels(1) wf.setsampwidth(2) wf.setframerate(sample_rate) wf.writeframes(pcm.tobytes()) return buf.getvalue() class AudioCodecManager: """Manages loading and caching of audio codec models for TTS decoding.""" def __init__(self): self._snac_model = None self._bicodec_tokenizer = None self._bicodec_repo_path = None self._dac_audio_codec = None def load_codec( self, audio_type: str, device: str = "cuda", model_repo_path: Optional[str] = None, ) -> None: """Load the appropriate codec for the given audio type.""" if audio_type == "snac": self._load_snac(device) elif audio_type == "bicodec": self._load_bicodec(device, model_repo_path) elif audio_type == "dac": self._load_dac(device) elif audio_type == "csm": pass # CSM decoding is built into the model (output_audio=True) else: raise ValueError(f"Unknown audio_type: {audio_type}") # ── Lazy loaders ───────────────────────────────────────────── def _load_snac(self, device: str) -> None: if self._snac_model is not None: return from snac import SNAC self._snac_model = ( SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval() ) logger.info("Loaded SNAC codec (24kHz)") def _load_bicodec(self, device: str, model_repo_path: Optional[str] = None) -> None: if self._bicodec_tokenizer is not None: return import os import sys import subprocess # Clone SparkAudio/Spark-TTS GitHub repo for the sparktts Python package 
# (same approach as training — the HF model repos don't contain the package) spark_code_dir = os.path.join( os.path.dirname(model_repo_path or "."), "Spark-TTS" ) sparktts_pkg = os.path.join(spark_code_dir, "sparktts") if not os.path.isdir(sparktts_pkg): logger.info(f"Cloning SparkAudio/Spark-TTS to {spark_code_dir}...") subprocess.run( [ "git", "clone", "--depth", "1", "https://github.com/SparkAudio/Spark-TTS", spark_code_dir, ], check = True, ) if spark_code_dir not in sys.path: sys.path.insert(0, spark_code_dir) from sparktts.models.audio_tokenizer import BiCodecTokenizer # BiCodecTokenizer needs the MODEL repo path (contains BiCodec/ weights) tokenizer_path = model_repo_path or spark_code_dir self._bicodec_repo_path = tokenizer_path self._bicodec_tokenizer = BiCodecTokenizer(tokenizer_path, device) logger.info(f"Loaded BiCodec tokenizer from {tokenizer_path}") def _load_dac(self, device: str) -> None: if self._dac_audio_codec is not None: return import os import sys import subprocess # Clone OuteTTS repo (same pattern as Spark-TTS / BiCodec) # The pip package has problematic dependencies; the notebook clones and # removes gguf_model.py, interface.py, __init__.py before importing. 
base_dir = os.path.dirname(os.path.abspath(__file__)) outetts_code_dir = os.path.join(base_dir, "OuteTTS") outetts_pkg = os.path.join(outetts_code_dir, "outetts") if not os.path.isdir(outetts_pkg): logger.info(f"Cloning edwko/OuteTTS to {outetts_code_dir}...") subprocess.run( [ "git", "clone", "--depth", "1", "https://github.com/edwko/OuteTTS", outetts_code_dir, ], check = True, ) # Remove files that pull in heavy / incompatible dependencies # (matches notebook: gguf_model.py is under models/, others under outetts/) remove_paths = [ os.path.join(outetts_pkg, "models", "gguf_model.py"), os.path.join(outetts_pkg, "interface.py"), os.path.join(outetts_pkg, "__init__.py"), ] for fpath in remove_paths: if os.path.exists(fpath): os.remove(fpath) logger.info(f"Removed {fpath}") if outetts_code_dir not in sys.path: sys.path.insert(0, outetts_code_dir) from outetts.version.v3.audio_processor import AudioProcessor from outetts.models.config import ModelConfig as OuteTTSModelConfig dummy_config = OuteTTSModelConfig( tokenizer_path = "OuteAI/Llama-OuteTTS-1.0-1B", device = device, audio_codec_path = None, ) processor = AudioProcessor(config = dummy_config) self._dac_audio_codec = processor.audio_codec logger.info("Loaded DAC audio codec") # ── Decoders ───────────────────────────────────────────────── def decode_snac( self, generated_ids: torch.Tensor, device: str ) -> Tuple[bytes, int]: """ Decode SNAC tokens (Orpheus) into WAV bytes. generated_ids: full model output including prompt tokens. Looks for START_OF_SPEECH (128257) marker, extracts codes after it, strips EOS (128258), redistributes 7-per-frame codes into 3 SNAC layers. Returns (wav_bytes, 24000). 
""" # Find START_OF_SPEECH token (128257) token_indices = (generated_ids == 128257).nonzero(as_tuple = True) if len(token_indices[1]) > 0: cropped = generated_ids[:, token_indices[1][-1] + 1 :] else: # Gracefully fall back to using entire output if marker not found logger.warning( "No START_OF_SPEECH token (128257) found — using full generated output" ) cropped = generated_ids row = cropped[0] # Remove EOS tokens (128258) row = row[row != 128258] # Trim to multiple of 7 row = row[: (len(row) // 7) * 7] if len(row) == 0: raise ValueError("No valid audio codes found after START_OF_SPEECH token") codes = [t.item() - 128266 for t in row] # Redistribute into 3 SNAC layers (7 codes per frame → 1+2+4) layer_1, layer_2, layer_3 = [], [], [] for i in range(len(codes) // 7): layer_1.append(codes[7 * i]) layer_2.append(codes[7 * i + 1] - 4096) layer_3.append(codes[7 * i + 2] - 8192) layer_3.append(codes[7 * i + 3] - 12288) layer_2.append(codes[7 * i + 4] - 16384) layer_3.append(codes[7 * i + 5] - 20480) layer_3.append(codes[7 * i + 6] - 24576) snac_codes = [ torch.tensor(layer).unsqueeze(0).to(device) for layer in [layer_1, layer_2, layer_3] ] with torch.no_grad(): audio = self._snac_model.decode(snac_codes) waveform = audio.squeeze().cpu().numpy() return _numpy_to_wav_bytes(waveform, 24000), 24000 def decode_csm(self, audio_values: torch.Tensor) -> Tuple[bytes, int]: """ Decode CSM output (already a waveform from model.generate(output_audio=True)). Returns (wav_bytes, 24000). """ waveform = audio_values[0].to(torch.float32).cpu().numpy() return _numpy_to_wav_bytes(waveform, 24000), 24000 def decode_bicodec(self, generated_text: str, device: str) -> Tuple[bytes, int]: """ Decode BiCodec tokens (Spark-TTS) from generated text. Extracts bicodec_semantic_N and bicodec_global_N tokens via regex. Returns (wav_bytes, sample_rate). 
""" semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", generated_text) global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", generated_text) logger.info( f"BiCodec decode: {len(global_matches)} global tokens, {len(semantic_matches)} semantic tokens" ) if len(global_matches) < 10: logger.info( f"BiCodec generated text (first 500 chars): {generated_text[:500]}" ) if not semantic_matches: raise ValueError("No bicodec_semantic tokens found in generated output") semantic_ids = ( torch.tensor([int(t) for t in semantic_matches]).long().unsqueeze(0) ) # Speaker encoder expects exactly 32 global tokens (token_num=32 in BiCodec config). # Pad with zeros or truncate to 32. GLOBAL_TOKEN_NUM = 32 if global_matches: raw = [int(t) for t in global_matches] else: raw = [] if len(raw) < GLOBAL_TOKEN_NUM: raw = raw + [0] * (GLOBAL_TOKEN_NUM - len(raw)) raw = raw[:GLOBAL_TOKEN_NUM] global_ids = torch.tensor(raw).long().unsqueeze(0) # (1, 32) self._bicodec_tokenizer.device = device self._bicodec_tokenizer.model.to(device) wav_np = self._bicodec_tokenizer.detokenize( global_ids.to(device), semantic_ids.to(device), ) sr = self._bicodec_tokenizer.config.get("sample_rate", 16000) return _numpy_to_wav_bytes(wav_np, sr), sr def decode_dac(self, generated_text: str, device: str) -> Tuple[bytes, int]: """ Decode DAC tokens (OuteTTS) from generated text. Extracts c1_N and c2_N codec code tokens via regex. Returns (wav_bytes, 24000). 
""" c1 = list(map(int, re.findall(r"<\|c1_(\d+)\|>", generated_text))) c2 = list(map(int, re.findall(r"<\|c2_(\d+)\|>", generated_text))) if not c1 or not c2: raise ValueError("No DAC code tokens (c1/c2) found in generated output") t = min(len(c1), len(c2)) c1 = c1[:t] c2 = c2[:t] codes = torch.tensor([[c1, c2]], dtype = torch.int64).to(device) with torch.no_grad(): audio = self._dac_audio_codec.decode(codes) waveform = audio.squeeze().cpu().numpy() return _numpy_to_wav_bytes(waveform, 24000), 24000 def decode( self, audio_type: str, device: str, token_ids: Optional[list] = None, text: Optional[str] = None, ) -> Tuple[bytes, int]: """Unified decode — dispatches to the right codec decoder.""" if audio_type == "snac": if not token_ids: raise ValueError("SNAC decoding requires token_ids") return self.decode_snac(torch.tensor([token_ids], dtype = torch.long), device) elif audio_type == "bicodec": if not text: raise ValueError("BiCodec decoding requires text") return self.decode_bicodec(text, device) elif audio_type == "dac": if not text: raise ValueError("DAC decoding requires text") return self.decode_dac(text, device) raise ValueError(f"Cannot decode audio_type: {audio_type}") # ── Cleanup ────────────────────────────────────────────────── def unload(self) -> None: """Release all codec models from memory.""" if self._snac_model is not None: del self._snac_model self._snac_model = None if self._bicodec_tokenizer is not None: del self._bicodec_tokenizer self._bicodec_tokenizer = None self._bicodec_repo_path = None if self._dac_audio_codec is not None: del self._dac_audio_codec self._dac_audio_codec = None logger.info("Unloaded all audio codecs") ================================================ FILE: studio/backend/core/inference/defaults.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """Default model lists for inference, split by platform.""" import utils.hardware.hardware as hw DEFAULT_MODELS_GGUF = [ "unsloth/Llama-3.2-1B-Instruct-GGUF", "unsloth/Llama-3.2-3B-Instruct-GGUF", "unsloth/Llama-3.1-8B-Instruct-GGUF", "unsloth/gemma-3-1b-it-GGUF", "unsloth/gemma-3-4b-it-GGUF", "unsloth/Qwen3-4B-GGUF", ] DEFAULT_MODELS_STANDARD = [ "unsloth/Qwen3-4B-Instruct-2507", "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit", "unsloth/Phi-3.5-mini-instruct", "unsloth/Gemma-3-4B-it", "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit", ] def get_default_models() -> list[str]: hw.get_device() # ensure detect_hardware() has run if hw.CHAT_ONLY: return list(DEFAULT_MODELS_GGUF) return list(DEFAULT_MODELS_STANDARD) ================================================ FILE: studio/backend/core/inference/inference.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Core inference backend - streamlined """ from unsloth import FastLanguageModel, FastVisionModel from unsloth.chat_templates import get_chat_template from transformers import TextStreamer from peft import PeftModel, PeftModelForCausalLM import json import sys import torch from pathlib import Path from typing import Optional, Union, Generator, Tuple from utils.models import ModelConfig, get_base_model_from_lora from utils.paths import is_model_cached from utils.utils import format_error_message from utils.hardware import get_device, clear_gpu_cache, log_gpu_memory from core.inference.audio_codecs import AudioCodecManager from io import StringIO import structlog from loggers import get_logger logger = get_logger(__name__) class HarmonyTextStreamer: """Streaming text decoder for gpt-oss harmony channel protocol. 
gpt-oss models emit multi-channel output using special tokens like ``<|channel|>analysis<|message|>...`` and ``<|channel|>final<|message|>...``. A plain ``TextIteratorStreamer(skip_special_tokens=True)`` strips the special tokens but leaves the channel names concatenated with content, producing garbled output such as ``analysisWe need to respond...assistantfinalHello!``. This streamer decodes with ``skip_special_tokens=False`` so the full harmony markup is visible, then uses **stateful incremental** parsing to emit properly-formatted text: - ```` emitted once when the ``analysis`` channel is first seen - Analysis content streamed incrementally - ```` emitted once when the ``final`` channel is first seen - Final content streamed incrementally This avoids the delta-on-transformed bug where wrapping tags shift position as content grows. Implements the same ``put`` / ``end`` / iterator interface as ``TextIteratorStreamer`` so ``generate_stream`` can use it as a drop-in replacement. """ import re as _re _HARMONY_RE = _re.compile( r"<\|channel\|>(\w+)<\|message\|>(.*?)(?=<\|end\|>|<\|channel\|>|\Z)", _re.DOTALL, ) def __init__(self, tokenizer, *, skip_prompt: bool = True, timeout: float = 0.2): import queue self.tokenizer = tokenizer self.skip_prompt = skip_prompt self.timeout = timeout self._queue: queue.Queue = queue.Queue() self._token_ids: list = [] self._prompt_len: int = 0 self._is_first_put: bool = True self._stop: bool = False # Stateful channel tracking — avoids delta-on-transformed bugs self._emitted_think_open: bool = False self._emitted_think_close: bool = False self._analysis_emitted: int = 0 # chars of analysis content emitted self._final_emitted: int = 0 # chars of final content emitted # ------------------------------------------------------------------ # put / end — called from the generation thread # ------------------------------------------------------------------ def put(self, value): """Receive new token IDs from model.generate().""" import torch if 
isinstance(value, torch.Tensor): # value shape: (batch, seq) — take first batch element ids = value[0].tolist() if value.dim() > 1 else value.tolist() elif isinstance(value, (list, tuple)): ids = list(value) else: ids = [value] if self._is_first_put and self.skip_prompt: # First call contains the full prompt; remember its length self._prompt_len = len(ids) self._token_ids = list(ids) self._is_first_put = False return self._token_ids.extend(ids) # Decode only the generated part (after the prompt) gen_ids = self._token_ids[self._prompt_len :] raw = self.tokenizer.decode(gen_ids, skip_special_tokens = False) self._process_incremental(raw) def end(self): """Signal generation is complete.""" # Final decode to capture any remaining content gen_ids = self._token_ids[self._prompt_len :] if gen_ids: raw = self.tokenizer.decode(gen_ids, skip_special_tokens = False) self._process_incremental(raw) # Close any open think tags if self._emitted_think_open and not self._emitted_think_close: self._queue.put("") self._emitted_think_close = True self._stop = True self._queue.put(None) # sentinel # ------------------------------------------------------------------ # Iterator interface — consumed by the streaming loop # ------------------------------------------------------------------ def __iter__(self): return self def __next__(self): from queue import Empty while True: try: val = self._queue.get(timeout = self.timeout) except Empty: if self._stop: raise StopIteration raise # propagate Empty so caller can check thread liveness if val is None: raise StopIteration return val # ------------------------------------------------------------------ # Stateful incremental harmony protocol parsing # ------------------------------------------------------------------ def _process_incremental(self, raw: str) -> None: """Parse harmony channels and emit deltas per-channel. 
Instead of transforming the entire raw text and computing a string delta (which breaks when wrapping ```` tags shift position), this tracks per-channel content lengths and emits: - ```` once when analysis channel first appears - analysis content deltas (computed on channel content directly) - ```` once when final channel first appears - final content deltas """ # If raw contains <|channel|> but no complete channel+message pair yet, # buffer silently — don't emit partial channel names as text. has_channel_token = "<|channel|>" in raw matches = list(self._HARMONY_RE.finditer(raw)) if has_channel_token and not matches: # Partial harmony markup still building — wait for more tokens return if not has_channel_token and not matches: # No harmony protocol at all — should not happen for gpt-oss # but handle gracefully by not emitting anything return for m in matches: channel = m.group(1).lower() content = m.group(2) if channel == "analysis": if not self._emitted_think_open: self._queue.put("") self._emitted_think_open = True new_content = content[self._analysis_emitted :] if new_content: self._analysis_emitted = len(content) self._queue.put(new_content) elif channel in ("final", "assistant"): if self._emitted_think_open and not self._emitted_think_close: self._queue.put("") self._emitted_think_close = True new_content = content[self._final_emitted :] if new_content: self._final_emitted = len(content) self._queue.put(new_content) class InferenceBackend: """Unified inference backend supporting text, vision, and LoRA models""" def __init__(self): self.models = {} self.active_model_name = None self.loading_models = set() self.loaded_local_models = [] # [(display_name, path), ...] from core.inference.defaults import get_default_models self.default_models = get_default_models() self.device = get_device().value self._audio_codec_manager = AudioCodecManager() # Thread safety — _generation_lock serializes model.generate() calls. 
# Must be a regular Lock (NOT RLock) because in async FastAPI, multiple # requests share the same event-loop thread, so RLock reentrancy lets # concurrent compare-mode requests race on the GPU. The lock is # acquired by the *background generation thread*, not the event-loop. import threading self._generation_lock = threading.Lock() self._model_state_lock = threading.Lock() logger.info(f"InferenceBackend initialized on {self.device}") @staticmethod def _normalize_top_k(top_k: int) -> int: # API supports -1 as "disable top-k"; transformers expects 0 to disable. return 0 if top_k < 0 else top_k def load_model( self, config: ModelConfig, max_seq_length: int = 2048, dtype = None, load_in_4bit: bool = True, hf_token: Optional[str] = None, trust_remote_code: bool = False, ) -> bool: """ Load any model: base, LoRA adapter, text, or vision. """ try: model_name = config.identifier # Check if already loaded if model_name in self.models and self.models[model_name].get("model"): logger.info(f"Model {model_name} already loaded") self.active_model_name = model_name return True # Check if currently loading if model_name in self.loading_models: logger.info(f"Model {model_name} is already being loaded") return False self.loading_models.add(model_name) self.models[model_name] = { "is_vision": config.is_vision, "is_lora": config.is_lora, "is_audio": config.is_audio, "audio_type": config.audio_type, "has_audio_input": config.has_audio_input, "model_path": config.path, "base_model": config.base_model if config.is_lora else None, "loaded_adapters": {}, "active_adapter": None, } # ── Audio model loading path ────────────────────────── if config.is_audio: audio_type = config.audio_type adapter_info = " (LoRA adapter)" if config.is_lora else "" logger.info( f"Loading audio ({audio_type}) model{adapter_info}: {model_name}" ) log_gpu_memory(f"Before loading {model_name}") if audio_type == "csm": from unsloth import FastModel from transformers import CsmForConditionalGeneration model, 
processor = FastModel.from_pretrained( config.path, auto_model = CsmForConditionalGeneration, load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) FastModel.for_inference(model) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = processor self.models[model_name]["processor"] = processor elif audio_type == "bicodec": import os from unsloth import FastModel if config.is_lora and config.base_model: # LoRA adapter: load from local adapter path. # base_model is e.g. /home/.../Spark-TTS-0.5B/LLM # The BiCodec weights are in the parent dir (Spark-TTS-0.5B/). base_path = config.base_model if os.path.isdir(base_path): abs_repo_path = os.path.abspath(os.path.dirname(base_path)) else: # base_model is an HF ID — download it from huggingface_hub import snapshot_download local_dir = base_path.split("/")[-1] repo_path = snapshot_download( base_path, local_dir = local_dir ) abs_repo_path = os.path.abspath(repo_path) logger.info( f"Spark-TTS LoRA: loading adapter from {config.path}, BiCodec from {abs_repo_path}" ) model, tokenizer = FastModel.from_pretrained( config.path, dtype = torch.float32, load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) else: # Base model: download full HF repo, then load from /LLM subfolder from huggingface_hub import snapshot_download hf_repo = config.path local_dir = hf_repo.split("/")[-1] repo_path = snapshot_download(hf_repo, local_dir = local_dir) abs_repo_path = os.path.abspath(repo_path) llm_path = os.path.join(abs_repo_path, "LLM") logger.info( f"Spark-TTS: downloaded repo to {repo_path}, loading LLM from {llm_path}" ) model, tokenizer = FastModel.from_pretrained( llm_path, dtype = torch.float32, load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) FastModel.for_inference(model) 
self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = tokenizer self.models[model_name]["model_repo_path"] = abs_repo_path elif audio_type == "dac": # OuteTTS uses FastModel (not FastLanguageModel) from unsloth import FastModel model, tokenizer = FastModel.from_pretrained( config.path, max_seq_length = max_seq_length, load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) FastModel.for_inference(model) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = tokenizer elif audio_type == "whisper": # Whisper ASR — uses FastModel with WhisperForConditionalGeneration from unsloth import FastModel from transformers import WhisperForConditionalGeneration model, tokenizer = FastModel.from_pretrained( config.path, auto_model = WhisperForConditionalGeneration, whisper_language = "English", whisper_task = "transcribe", load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) FastModel.for_inference(model) model.eval() # Create ASR pipeline (per notebook) from transformers import pipeline as hf_pipeline whisper_pipe = hf_pipeline( "automatic-speech-recognition", model = model, tokenizer = tokenizer.tokenizer, feature_extractor = tokenizer.feature_extractor, processor = tokenizer, return_language = True, torch_dtype = torch.float16, ) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = tokenizer self.models[model_name]["whisper_pipeline"] = whisper_pipe else: # SNAC (Orpheus) uses FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name = config.path, max_seq_length = max_seq_length, load_in_4bit = False, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) FastLanguageModel.for_inference(model) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = tokenizer # Load 
the external codec for TTS audio types # (Whisper is ASR, audio_vlm is audio input — neither needs a codec) if audio_type not in ("whisper", "audio_vlm"): model_repo_path = self.models[model_name].get("model_repo_path") self._audio_codec_manager.load_codec( audio_type, self.device, model_repo_path = model_repo_path ) self.active_model_name = model_name self.loading_models.discard(model_name) logger.info(f"Successfully loaded audio model: {model_name}") log_gpu_memory(f"After loading {model_name}") return True model_type = "vision" if config.is_vision else "text" adapter_info = ( " (LoRA adapter)" if self.models[model_name]["is_lora"] else "" ) logger.info(f"Loading {model_type} model{adapter_info}: {model_name}") log_gpu_memory(f"Before loading {model_name}") # Load model - same approach for base models and LoRA adapters if config.is_vision: # Vision model (or vision LoRA adapter) model, processor = FastVisionModel.from_pretrained( model_name = config.path, # Can be base model OR LoRA adapter path max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) # Apply inference optimization FastVisionModel.for_inference(model) # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast) # instead of a proper Processor for some models (e.g. Gemma-3). # In that case, load the real processor from the base model. from transformers import ProcessorMixin if not ( isinstance(processor, ProcessorMixin) or hasattr(processor, "image_processor") ): # For LoRA adapters, use the base model. For local merged exports, # read export_metadata.json to find the original base model. 
processor_source = ( config.base_model if config.is_lora else config.identifier ) if not config.is_lora and config.is_local: _meta_path = Path(config.path) / "export_metadata.json" try: if _meta_path.exists(): _meta = json.loads(_meta_path.read_text()) if _meta.get("base_model"): processor_source = _meta["base_model"] except Exception: pass logger.warning( f"FastVisionModel returned {type(processor).__name__} (no image_processor) " f"for '{model_name}' — loading proper processor from '{processor_source}'" ) from transformers import AutoProcessor processor = AutoProcessor.from_pretrained( processor_source, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) logger.info( f"Loaded {type(processor).__name__} from {processor_source}" ) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = processor self.models[model_name]["processor"] = processor else: # Text model (or text LoRA adapter) model, tokenizer = FastLanguageModel.from_pretrained( model_name = config.path, # Can be base model OR LoRA adapter path max_seq_length = max_seq_length, dtype = dtype, load_in_4bit = load_in_4bit, token = hf_token if hf_token and hf_token.strip() else None, trust_remote_code = trust_remote_code, ) # Apply inference optimization FastLanguageModel.for_inference(model) self.models[model_name]["model"] = model self.models[model_name]["tokenizer"] = tokenizer # Load chat template info self._load_chat_template_info(model_name) self.active_model_name = model_name self.loading_models.discard(model_name) logger.info(f"Successfully loaded model: {model_name}") log_gpu_memory(f"After loading {model_name}") return True except Exception as e: logger.error(f"Failed to load model: {e}") error_msg = format_error_message(e, config.identifier) # Cleanup on failure if model_name in self.models: del self.models[model_name] self.loading_models.discard(model_name) raise Exception(error_msg) def unload_model(self, model_name: str) -> 
bool: """ Completely removes a model from the registry and clears GPU memory. """ if model_name in self.models: try: # If this was an audio model, clean up codecs if self.models[model_name].get("is_audio"): self._audio_codec_manager.unload() logger.info(f"Unloading model '{model_name}' from memory.") # Delete the model entry from our registry del self.models[model_name] # Clear the active model if it was the one being unloaded if self.active_model_name == model_name: self.active_model_name = None # Clear GPU memory cache clear_gpu_cache() # Remove stale compiled cache so the next model gets a fresh one from utils.cache_cleanup import clear_unsloth_compiled_cache clear_unsloth_compiled_cache() logger.info(f"Model '{model_name}' successfully unloaded.") return True except Exception as e: logger.error(f"Error while unloading model '{model_name}': {e}") return False else: logger.warning( f"Attempted to unload model '{model_name}', but it was not found in the registry." ) return True def revert_to_base_model(self, base_model_name: str) -> bool: """ Reverts the model to its pristine base state by unloading AND deleting all adapter configurations, as instructed. """ if base_model_name not in self.models: return False model = self.models[base_model_name].get("model") try: # Step 1: Unload the adapter weights if model is a PeftModel. if isinstance(model, (PeftModel, PeftModelForCausalLM)): logger.info(f"Unloading LoRA adapters from '{base_model_name}'...") unwrapped_base_model = model.unload() self.models[base_model_name]["model"] = unwrapped_base_model model = unwrapped_base_model # Step 2: Clear any lingering peft_config from the unwrapped model. # After model.unload(), the base model may still carry a peft_config # attribute. Removing it ensures PeftModel.from_pretrained() gets # a clean base model without "multiple adapters" warnings. 
if hasattr(model, "peft_config"): del model.peft_config logger.info(f"Model '{base_model_name}' reverted to clean base state.") return True except Exception as e: logger.error(f"Failed to revert model to base state: {e}") import traceback logger.error(traceback.format_exc()) return False def load_for_eval( self, lora_path: str, max_seq_length: int = 2048, dtype = None, load_in_4bit: bool = True, hf_token: Optional[str] = None, ) -> Tuple[bool, Optional[str], Optional[str]]: """ Final Corrected Version: Ensures the base model and the specified adapter are loaded. This function is idempotent and handles all states correctly. """ try: from utils.models import ModelConfig lora_config = ModelConfig.from_lora_path(lora_path, hf_token) if not lora_config: return False, None, None base_model_name = lora_config.base_model # 1. Load the base model if it's not already in memory if base_model_name not in self.models or not self.models[ base_model_name ].get("model"): logger.info(f"Base model '{base_model_name}' not loaded, loading now.") base_config = ModelConfig.from_ui_selection( base_model_name, None, is_lora = False ) if not self.load_model( base_config, max_seq_length, dtype, load_in_4bit, hf_token ): return False, None, None self.active_model_name = base_model_name # 2. Determine the required adapter name from the user's selection adapter_name = lora_path.split("/")[-1].replace(".", "_") # 3. Call our robust load_adapter function to ensure this specific adapter is loaded. # It will only load from disk if the model doesn't already have it. adapter_success = self.load_adapter( base_model_name = base_model_name, adapter_path = lora_path, adapter_name = adapter_name, ) if not adapter_success: return False, base_model_name, None # 4. Return the correct, verified adapter name for the UI logic to use. 
return True, base_model_name, adapter_name except Exception as e: logger.error(f"Error during load_for_eval: {e}") import traceback logger.error(traceback.format_exc()) return False, None, None def load_adapter( self, base_model_name: str, adapter_path: str, adapter_name: str ) -> bool: """ Loads an adapter onto the model ONLY if it's not already attached. """ model = self.models[base_model_name].get("model") # Check if this adapter name is already part of the model's config. This is the most reliable check. if hasattr(model, "peft_config") and adapter_name in model.peft_config: logger.info( f"Adapter '{adapter_name}' is already attached to the model. Skipping load." ) return True try: logger.info( f"Loading new adapter '{adapter_name}' from '{adapter_path}' onto {base_model_name}" ) model.load_adapter(adapter_path, adapter_name = adapter_name) # Update our internal registry ONLY after a successful load. if "loaded_adapters" not in self.models[base_model_name]: self.models[base_model_name]["loaded_adapters"] = {} self.models[base_model_name]["loaded_adapters"][adapter_name] = adapter_path total_adapters = len(getattr(model, "peft_config", {})) logger.info( f"Adapter '{adapter_name}' loaded successfully. (Total unique adapters on model: {total_adapters})" ) return True except Exception as e: logger.error(f"Failed to load adapter '{adapter_name}': {e}") return False def set_active_adapter(self, base_model_name: str, adapter_name: str) -> bool: """ Sets the active adapter for generation. This replaces the flawed 'enable_adapter'. """ model = self.models[base_model_name].get("model") try: logger.info(f"Setting active adapter to: '{adapter_name}'") model.set_adapter(adapter_name) self.models[base_model_name]["active_adapter"] = adapter_name return True except Exception as e: # This will catch the "adapter not found" error if something goes wrong. 
logger.error(f"Failed to set active adapter to '{adapter_name}': {e}") return False def _apply_adapter_state(self, use_adapter: Optional[Union[bool, str]]) -> None: """ Apply adapter state before generation. Must be called under _generation_lock. Uses PEFT's disable_adapter_layers() / enable_adapter_layers() which toggle a boolean flag on each LoRA layer. Unsloth's fast_linear_forward checks this flag (proj.disable_adapters) and skips LoRA computation when True. This is non-destructive — no model unloading/reloading needed. Args: use_adapter: None = no change, False = disable (base model), True = enable current adapter, str = enable specific adapter. """ if use_adapter is None: return base = self.active_model_name if not base or base not in self.models: return model_info = self.models[base] model = model_info.get("model") if model is None: return if use_adapter is False: # Disable LoRA layers → base model output if isinstance(model, (PeftModel, PeftModelForCausalLM)): logger.info( f"Compare mode: disabling adapters on '{base}' for base model generation" ) model.base_model.disable_adapter_layers() else: logger.info( f"Compare mode: model '{base}' is not a PeftModel, already base" ) elif use_adapter is True: # Re-enable LoRA layers → adapter output if isinstance(model, (PeftModel, PeftModelForCausalLM)): logger.info( f"Compare mode: enabling adapters on '{base}' for LoRA generation" ) model.base_model.enable_adapter_layers() else: logger.warning("use_adapter=true but model is not a PeftModel") elif isinstance(use_adapter, str): # Enable adapters and set the specific one active if isinstance(model, (PeftModel, PeftModelForCausalLM)): logger.info( f"Compare mode: enabling adapter '{use_adapter}' on '{base}'" ) model.base_model.enable_adapter_layers() self.set_active_adapter(base, use_adapter) else: logger.warning( f"use_adapter='{use_adapter}' but model is not a PeftModel" ) def generate_with_adapter_control( self, use_adapter: Optional[Union[bool, str]] = None, 
cancel_event = None, **gen_kwargs, ) -> Generator[str, None, None]: """ Thread-safe generation with optional adapter toggling. The adapter toggle + model.generate() are serialized by _generation_lock inside the background generation thread — NOT in the event-loop thread. This prevents the RLock-reentrant race that occurs when two async SSE handlers share the same event-loop thread. Args: use_adapter: Adapter control (None/False/True/str). See _apply_adapter_state. **gen_kwargs: Forwarded to generate_chat_response. """ yield from self._generate_chat_response_inner( cancel_event = cancel_event, _adapter_state = use_adapter, **gen_kwargs ) def generate_chat_response( self, messages: list, system_prompt: str, image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, ) -> Generator[str, None, None]: """ Generate response for text or vision models. The generation lock is acquired by the background generation thread. """ yield from self._generate_chat_response_inner( messages = messages, system_prompt = system_prompt, image = image, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, max_new_tokens = max_new_tokens, repetition_penalty = repetition_penalty, cancel_event = cancel_event, ) def _generate_chat_response_inner( self, messages: list, system_prompt: str = "", image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, _adapter_state = None, ) -> Generator[str, None, None]: """ Inner generation logic. Called by both generate_chat_response and generate_with_adapter_control. _adapter_state is passed to generate_stream/vision so the background thread can toggle adapters under the generation lock. 
""" if not self.active_model_name: yield "Error: No active model" return model_info = self.models[self.active_model_name] is_vision = model_info.get("is_vision", False) tokenizer = model_info.get("tokenizer") or model_info.get("processor") # Unwrap processor → raw tokenizer for VLMs on the text path tokenizer = getattr(tokenizer, "tokenizer", tokenizer) top_k = self._normalize_top_k(top_k) if is_vision and image: # Vision model generation (only when an image is actually provided) # Check that the stored processor can actually handle images. # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast) # instead of a proper ProcessorMixin for some models (e.g. Gemma-3). from transformers import ProcessorMixin processor = model_info.get("processor") has_image_processing = processor is not None and ( isinstance(processor, ProcessorMixin) or hasattr(processor, "image_processor") ) if has_image_processing: yield from self._generate_vision_response( messages, system_prompt, image, temperature, top_p, top_k, min_p, max_new_tokens, repetition_penalty, cancel_event = cancel_event, ) return else: logger.warning( f"Model '{self.active_model_name}' is marked as vision but its processor " f"({type(processor).__name__}) has no image_processor — " f"falling back to text-only generation (image will be ignored)." 
) # Text path: Use training pipeline approach # Messages are already in ChatML format from eval.py # Step 1: Apply get_chat_template if model is in mapper try: from utils.datasets import ( MODEL_TO_TEMPLATE_MAPPER, get_tokenizer_chat_template, ) model_name_lower = self.active_model_name.lower() # Check if model has a registered template if model_name_lower in MODEL_TO_TEMPLATE_MAPPER: template_name = MODEL_TO_TEMPLATE_MAPPER[model_name_lower] logger.info( f"Applying chat template '{template_name}' for {self.active_model_name}" ) # This modifies the tokenizer with the correct template tokenizer = get_chat_template( tokenizer, chat_template = template_name, ) else: logger.info( f"No registered Unsloth template for {self.active_model_name}, using tokenizer default" ) except Exception as e: logger.warning(f"Could not apply get_chat_template: {e}") # Step 2: Format with tokenizer.apply_chat_template() try: if not (hasattr(tokenizer, "chat_template") and tokenizer.chat_template): raise ValueError( f"Model '{self.active_model_name}' has no chat_template set in its " f"tokenizer_config.json. This is usually a problem with the model's " f"HuggingFace repository — it is missing a 'chat_template' key. " f"Please use a model that includes a chat template, or manually set " f"one via tokenizer.chat_template before inference." 
) formatted_prompt = tokenizer.apply_chat_template( messages, tokenize = False, add_generation_prompt = True ) logger.debug(f"Formatted prompt: {formatted_prompt[:200]}...") except Exception as e: logger.error(f"Error applying chat template: {e}") # Fallback to manual formatting formatted_prompt = self.format_chat_prompt(messages, system_prompt) # Step 3: Generate yield from self.generate_stream( formatted_prompt, temperature, top_p, top_k, min_p, max_new_tokens, repetition_penalty, cancel_event = cancel_event, _adapter_state = _adapter_state, ) def _generate_vision_response( self, messages, system_prompt, image, temperature, top_p, top_k, min_p, max_new_tokens, repetition_penalty, cancel_event = None, ) -> Generator[str, None, None]: """Handle vision model generation with true token-by-token streaming.""" model_info = self.models[self.active_model_name] model = model_info["model"] processor = model_info["processor"] # FastVisionModel may return a raw tokenizer (e.g. GemmaTokenizerFast) # instead of a Processor for some models. Safe unwrap for tokenize-only ops. raw_tokenizer = getattr(processor, "tokenizer", processor) # Extract user message user_message = "" if messages and messages[-1]["role"] == "user": import re user_message = messages[-1]["content"] user_message = re.sub(r"]*>", "", user_message).strip() if not user_message: user_message = "Describe this image." 
if image else "Hello" # Prepare vision messages if image: vision_messages = [ { "role": "user", "content": [ {"type": "image"}, {"type": "text", "text": user_message}, ], } ] input_text = processor.apply_chat_template( vision_messages, add_generation_prompt = True, tokenize = False ) inputs = processor( image, input_text, add_special_tokens = False, return_tensors = "pt", ).to(self.device) else: # Text-only for vision model formatted_prompt = self.format_chat_prompt(messages, system_prompt) inputs = raw_tokenizer(formatted_prompt, return_tensors = "pt").to( self.device ) # Stream with TextIteratorStreamer + background thread try: from transformers import TextIteratorStreamer import threading streamer = TextIteratorStreamer( raw_tokenizer, skip_prompt = True, skip_special_tokens = True, timeout = 0.2, ) generation_kwargs = dict( **inputs, streamer = streamer, max_new_tokens = max_new_tokens, use_cache = True, do_sample = temperature > 0, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, ) err: dict[str, str] = {} def generate_fn(): with self._generation_lock: try: model.generate(**generation_kwargs) except Exception as e: err["msg"] = str(e) logger.error(f"Vision generation error in thread: {e}") finally: try: streamer.end() except Exception: pass thread = threading.Thread(target = generate_fn) thread.start() output = "" from queue import Empty generation_complete = False try: while True: if cancel_event is not None and cancel_event.is_set(): break try: new_token = next(streamer) except StopIteration: generation_complete = True break except Empty: if not thread.is_alive(): generation_complete = True break continue if new_token: output += new_token cleaned = self._clean_generated_text(output) yield cleaned finally: if cancel_event is not None and not generation_complete: cancel_event.set() thread.join(timeout = 10) if thread.is_alive(): logger.warning( "Vision generation thread did not exit after cancel/join timeout" ) if err.get("msg"): yield 
f"Error: {err['msg']}" except Exception as e: logger.error(f"Vision generation error: {e}") yield f"Error: {str(e)}" def generate_audio_input_response( self, messages, system_prompt, audio_array, temperature, top_p, top_k, min_p, max_new_tokens, repetition_penalty, cancel_event = None, ) -> Generator[str, None, None]: """Handle audio input (ASR) generation — accepts audio numpy array, streams text output. Uses processor.apply_chat_template with audio embedded in messages (Gemma 3n pattern). """ import threading import numpy as np model_info = self.models[self.active_model_name] model = model_info["model"] processor = model_info.get("processor") or model_info.get("tokenizer") raw_tokenizer = getattr(processor, "tokenizer", processor) # Extract last user text — default matches notebook prompt user_text = "Please transcribe this audio." if messages: for msg in reversed(messages): if msg["role"] == "user" and msg.get("content"): user_text = msg["content"] break # Use ASR-specific system prompt if user hasn't set a custom one if not system_prompt: system_prompt = "You are an assistant that transcribes speech accurately." 
# Build messages in Gemma 3n format — audio goes INTO apply_chat_template audio_messages = [ {"role": "system", "content": [{"type": "text", "text": system_prompt}]}, { "role": "user", "content": [ {"type": "audio", "audio": audio_array}, {"type": "text", "text": user_text}, ], }, ] # apply_chat_template handles audio embedding + tokenization in one step inputs = processor.apply_chat_template( audio_messages, add_generation_prompt = True, tokenize = True, return_dict = True, return_tensors = "pt", truncation = False, ).to(self.device) try: from transformers import TextIteratorStreamer from queue import Empty streamer = TextIteratorStreamer( raw_tokenizer, skip_prompt = True, skip_special_tokens = True, timeout = 0.2, ) # Notebook uses do_sample=False for ASR (greedy decoding for accuracy) generation_kwargs = dict( **inputs, streamer = streamer, max_new_tokens = max_new_tokens, use_cache = True, do_sample = False, ) err: dict[str, str] = {} def generate_fn(): with self._generation_lock: try: model.generate(**generation_kwargs) except Exception as e: err["msg"] = str(e) logger.error(f"Audio input generation error in thread: {e}") finally: try: streamer.end() except Exception: pass thread = threading.Thread(target = generate_fn) thread.start() output = "" try: while True: if cancel_event is not None and cancel_event.is_set(): break try: new_token = next(streamer) except StopIteration: break except Empty: if not thread.is_alive(): break continue if new_token: output += new_token yield new_token finally: if cancel_event is not None: cancel_event.set() thread.join(timeout = 10) if thread.is_alive(): logger.warning( "Audio input generation thread did not exit after cancel/join timeout" ) if err.get("msg"): yield f"Error: {err['msg']}" except Exception as e: logger.error(f"Audio input generation error: {e}") yield f"Error: {str(e)}" def generate_whisper_response( self, audio_array, cancel_event = None ) -> Generator[str, None, None]: """Whisper ASR — takes audio numpy 
array, yields transcribed text. Uses the pre-built transformers pipeline (created during model loading). """ model_info = self.models[self.active_model_name] whisper_pipe = model_info.get("whisper_pipeline") if not whisper_pipe: yield "Error: Whisper pipeline not initialized" return try: with self._generation_lock: result = whisper_pipe({"raw": audio_array, "sampling_rate": 16000}) text = result.get("text", "") if isinstance(result, dict) else str(result) if text: yield text except Exception as e: logger.error(f"Whisper ASR error: {e}") yield f"Error: {str(e)}" def _is_gpt_oss_model(self, model_name: str = None) -> bool: """Check if the given (or active) model uses the gpt-oss harmony protocol.""" name = (model_name or self.active_model_name or "").lower() try: from utils.datasets import MODEL_TO_TEMPLATE_MAPPER # Exact match if MODEL_TO_TEMPLATE_MAPPER.get(name) == "gpt-oss": return True # Partial match (e.g. name-bnb-4bit variants) for key, tmpl in MODEL_TO_TEMPLATE_MAPPER.items(): if tmpl == "gpt-oss" and (key in name or name in key): return True except Exception: pass return "gpt-oss" in name def generate_stream( self, prompt: str, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, _adapter_state = None, ) -> Generator[str, None, None]: """Generate streaming text response (text models only). _adapter_state: if not None, the background thread toggles adapters before model.generate(), all under _generation_lock. """ if not self.active_model_name: yield "Error: No active model" return model_info = self.models[self.active_model_name] model = model_info["model"] # For VLMs the stored "tokenizer" is actually the processor. # Unwrap to get the real tokenizer so TextIteratorStreamer's # skip_prompt / skip_special_tokens work correctly. 
tokenizer = model_info["tokenizer"] tokenizer = getattr(tokenizer, "tokenizer", tokenizer) try: inputs = tokenizer(prompt, return_tensors = "pt").to(model.device) from transformers import TextIteratorStreamer import threading # Use HarmonyTextStreamer for gpt-oss models to properly parse # the multi-channel harmony protocol into tags if self._is_gpt_oss_model(): try: streamer = HarmonyTextStreamer( tokenizer, skip_prompt = True, timeout = 0.2, ) except Exception as e: logger.warning( f"HarmonyTextStreamer init failed, falling back: {e}" ) streamer = TextIteratorStreamer( tokenizer, skip_prompt = True, skip_special_tokens = True, timeout = 0.2, ) else: streamer = TextIteratorStreamer( tokenizer, skip_prompt = True, skip_special_tokens = True, timeout = 0.2, ) generation_kwargs = dict( **inputs, streamer = streamer, max_new_tokens = max_new_tokens, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, repetition_penalty = repetition_penalty, do_sample = temperature > 0, eos_token_id = tokenizer.eos_token_id, pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id, ) if cancel_event is not None: from transformers.generation.stopping_criteria import ( StoppingCriteria, StoppingCriteriaList, ) class _CancelCriteria(StoppingCriteria): def __init__(self, ev): self.ev = ev def __call__(self, input_ids, scores, **kwargs): return self.ev.is_set() generation_kwargs["stopping_criteria"] = StoppingCriteriaList( [_CancelCriteria(cancel_event)] ) def generate_fn(): with self._generation_lock: try: if _adapter_state is not None: self._apply_adapter_state(_adapter_state) model.generate(**generation_kwargs) except Exception as e: err["msg"] = str(e) logger.error(f"Generation error: {e}") finally: try: streamer.end() except Exception: pass err: dict[str, str] = {} thread = threading.Thread(target = generate_fn) thread.start() output = "" from queue import Empty generation_complete = False try: while True: if cancel_event 
is not None and cancel_event.is_set(): break try: new_token = next(streamer) except StopIteration: generation_complete = True break except Empty: if not thread.is_alive(): generation_complete = True break continue if new_token: output += new_token cleaned = self._clean_generated_text(output) yield cleaned finally: # Only set cancel_event when we exited early (user cancel), # NOT on normal completion. cancel_event is a shared mp.Event # — setting it unconditionally would leave a stale cancel # signal that could interfere with the next serialized # generation request (e.g. in compare mode). if cancel_event is not None and not generation_complete: cancel_event.set() thread.join(timeout = 10) if thread.is_alive(): logger.warning( "Generation thread did not exit after cancel/join timeout" ) if err.get("msg"): yield f"Error: {err['msg']}" except Exception as e: logger.error(f"Error during generation: {e}") yield f"Error: {str(e)}" # ── Audio (TTS) Generation ──────────────────────────────────── def generate_audio_response( self, text: str, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 50, min_p: float = 0.0, max_new_tokens: int = 2048, repetition_penalty: float = 1.0, use_adapter: Optional[Union[bool, str]] = None, ) -> Tuple[bytes, int]: """ Generate audio from text for TTS models. Returns (wav_bytes, sample_rate). Blocking — generates complete audio before returning. 
""" if not self.active_model_name: raise RuntimeError("No active model") model_info = self.models[self.active_model_name] audio_type = model_info.get("audio_type") model = model_info["model"] tokenizer = model_info.get("tokenizer") if not audio_type: raise RuntimeError(f"Model {self.active_model_name} is not an audio model") top_k = self._normalize_top_k(top_k) with self._generation_lock: if use_adapter is not None: self._apply_adapter_state(use_adapter) if audio_type == "snac": return self._generate_snac( model, tokenizer, text, temperature, top_p, max_new_tokens, repetition_penalty, ) elif audio_type == "csm": processor = model_info.get("processor", tokenizer) return self._generate_csm(model, processor, text, max_new_tokens) elif audio_type == "bicodec": return self._generate_bicodec( model, tokenizer, text, temperature, top_k, max_new_tokens ) elif audio_type == "dac": return self._generate_dac( model, tokenizer, text, temperature, top_k, top_p, min_p, max_new_tokens, repetition_penalty, ) else: raise RuntimeError(f"Unknown audio_type: {audio_type}") def _generate_snac( self, model, tokenizer, text, temperature, top_p, max_new_tokens, repetition_penalty, ): """Generate audio using SNAC codec (Orpheus).""" device = model.device start_token = torch.tensor([[128259]], device = device) # START_OF_HUMAN end_tokens = torch.tensor( [[128009, 128260]], device = device ) # EOT, END_OF_HUMAN text_ids = tokenizer(text, return_tensors = "pt").input_ids.to(device) input_ids = torch.cat([start_token, text_ids, end_tokens], dim = 1) attention_mask = torch.ones_like(input_ids) generated = model.generate( input_ids = input_ids, attention_mask = attention_mask, max_new_tokens = max_new_tokens, do_sample = True, temperature = temperature, top_p = top_p, repetition_penalty = repetition_penalty, eos_token_id = 128258, # END_OF_SPEECH use_cache = True, ) return self._audio_codec_manager.decode_snac(generated, str(device)) def _generate_csm(self, model, processor, text, 
max_new_tokens): """Generate audio using CSM (Sesame).""" speaker_id = 0 inputs = processor( f"[{speaker_id}]{text}", add_special_tokens = True, return_tensors = "pt" ).to(model.device) audio_values = model.generate( **inputs, max_new_tokens = max_new_tokens, output_audio = True ) return self._audio_codec_manager.decode_csm(audio_values) def _generate_bicodec( self, model, tokenizer, text, temperature, top_k, max_new_tokens ): """Generate audio using BiCodec (Spark-TTS).""" prompt = ( "<|task_tts|><|start_content|>" + text + "<|end_content|><|start_global_token|>" ) inputs = tokenizer([prompt], return_tensors = "pt").to(model.device) generated = model.generate( **inputs, max_new_tokens = max_new_tokens, do_sample = True, temperature = temperature, top_k = top_k, eos_token_id = tokenizer.eos_token_id, pad_token_id = tokenizer.pad_token_id, ) new_tokens = generated[:, inputs.input_ids.shape[1] :] decoded_text = tokenizer.batch_decode(new_tokens, skip_special_tokens = False)[0] return self._audio_codec_manager.decode_bicodec(decoded_text, str(model.device)) def _generate_dac( self, model, tokenizer, text, temperature, top_k, top_p, min_p, max_new_tokens, repetition_penalty, ): """Generate audio using DAC (OuteTTS). Follows Oute_TTS_(1B).ipynb exactly.""" # Monkey-patch RepetitionPenaltyLogitsProcessor with a 64-token penalty # window (same as the OuteTTS notebook) to avoid degenerate repetition. 
self._patch_repetition_penalty_processor() prompt = ( "<|im_start|>\n<|text_start|>" + text + "<|text_end|>\n<|audio_start|><|global_features_start|>\n" ) with torch.inference_mode(): with torch.amp.autocast("cuda", dtype = model.dtype): inputs = tokenizer([prompt], return_tensors = "pt").to(model.device) generated = model.generate( **inputs, temperature = temperature, top_k = top_k, top_p = top_p, min_p = min_p, repetition_penalty = repetition_penalty, max_new_tokens = max_new_tokens, ) decoded_text = tokenizer.batch_decode(generated, skip_special_tokens = False)[0] return self._audio_codec_manager.decode_dac(decoded_text, str(model.device)) _repetition_penalty_patched = False @classmethod def _patch_repetition_penalty_processor(cls): """ Monkey-patch transformers' RepetitionPenaltyLogitsProcessor with a 64-token sliding window variant (from the OuteTTS notebook). Only applied once per process. """ if cls._repetition_penalty_patched: return cls._repetition_penalty_patched = True from transformers import LogitsProcessor import transformers.generation.utils as generation_utils class RepetitionPenaltyLogitsProcessorPatch(LogitsProcessor): def __init__(self, penalty: float): self.penalty_last_n = 64 if not isinstance(penalty, float) or penalty <= 0: raise ValueError( f"`penalty` has to be a positive float, but is {penalty}" ) self.penalty = penalty @torch.no_grad() def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: if self.penalty_last_n == 0 or self.penalty == 1.0: return scores batch_size, seq_len = input_ids.shape vocab_size = scores.shape[-1] for b in range(batch_size): start_index = max(0, seq_len - self.penalty_last_n) window_indices = input_ids[b, start_index:] if window_indices.numel() == 0: continue for token_id in set(window_indices.tolist()): if token_id >= vocab_size: continue logit = scores[b, token_id] scores[b, token_id] = ( logit * self.penalty if logit <= 0 else logit / self.penalty ) return scores 
generation_utils.RepetitionPenaltyLogitsProcessor = ( RepetitionPenaltyLogitsProcessorPatch ) logger.info( "Patched RepetitionPenaltyLogitsProcessor with 64-token window for OuteTTS" ) def format_chat_prompt(self, messages: list, system_prompt: str = None) -> str: if not self.active_model_name or self.active_model_name not in self.models: logger.error("No active model available") return "" if self.models[self.active_model_name].get("tokenizer") is None: logger.error("Tokenizer not loaded for active model") return "" chat_template_info = self.models[self.active_model_name].get( "chat_template_info", {} ) tokenizer = self.models[self.active_model_name]["tokenizer"] tokenizer = getattr(tokenizer, "tokenizer", tokenizer) chat_messages = [] if system_prompt: chat_messages.append({"role": "system", "content": system_prompt}) last_role = "system" if system_prompt else None for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role in ["system", "user", "assistant"] and content.strip(): if role == last_role: logger.debug( f"Skipping consecutive {role} message to maintain alternation" ) continue if role == "user": import re clean_content = re.sub(r"<[^>]+>", "", content).strip() if clean_content: chat_messages.append({"role": role, "content": clean_content}) last_role = role elif role == "assistant" and content.strip(): chat_messages.append({"role": role, "content": content}) last_role = role elif role == "system": continue if chat_messages and chat_messages[-1]["role"] == "assistant": logger.debug( "Removing final assistant message to ensure proper alternation" ) chat_messages.pop() logger.info(f"Sending {len(chat_messages)} messages to tokenizer:") for i, msg in enumerate(chat_messages): logger.info(f" {i}: {msg['role']} - {msg['content'][:50]}...") try: formatted_prompt = tokenizer.apply_chat_template( chat_messages, tokenize = False, add_generation_prompt = True ) logger.info(f"Successfully applied tokenizer's native chat template") return 
formatted_prompt except Exception as e: error_msg = str(e).lower() if ( "chat_template is not set" in error_msg or "no template argument" in error_msg ): logger.info( f"Base model detected - no built-in chat template available, using fallback formatting" ) else: logger.warning(f"Failed to apply tokenizer chat template: {e}") logger.debug( f"""Failed with messages: {[f"{m['role']}: {m['content'][:30]}..." for m in chat_messages]}""" ) if chat_template_info.get("has_template", False): logger.info( "Falling back to manual template formatting based on detected patterns" ) template_type = chat_template_info.get("format_type", "generic") manual_prompt = self._format_chat_manual( chat_messages, template_type, chat_template_info.get("special_tokens", {}), ) logger.info(f"Manual template result: {manual_prompt[:200]}...") return manual_prompt else: logger.info("Using generic chat formatting for base model") return self._format_generic_template(chat_messages, {}) def _format_chat_manual( self, messages: list, template_type: str, special_tokens: dict ) -> str: """ Manual chat formatting fallback for when tokenizer template fails Args: messages: List of message dictionaries template_type: Detected template type special_tokens: Dictionary of special tokens Returns: str: Manually formatted prompt """ if template_type == "llama3": return self._format_llama3_template(messages, special_tokens) elif template_type == "mistral": return self._format_mistral_template(messages, special_tokens) elif template_type == "chatml": return self._format_chatml_template(messages, special_tokens) elif template_type == "alpaca": return self._format_alpaca_template(messages, special_tokens) else: return self._format_generic_template(messages, special_tokens) def _format_llama3_template(self, messages: list, special_tokens: dict) -> str: """Format messages using Llama 3 template""" bos_token = special_tokens.get("bos_token", "<|begin_of_text|>") formatted = bos_token for msg in messages: role = 
msg["role"] content = msg["content"] formatted += ( f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>" ) formatted += "<|start_header_id|>assistant<|end_header_id|>\n\n" return formatted def _format_mistral_template(self, messages: list, special_tokens: dict) -> str: """Format messages using Mistral template""" bos_token = special_tokens.get("bos_token", "") formatted = bos_token system_msg = None conversation = [] for msg in messages: if msg["role"] == "system": system_msg = msg["content"] else: conversation.append(msg) i = 0 while i < len(conversation): if conversation[i]["role"] == "user": user_content = conversation[i]["content"] if system_msg and i == 0: user_content = f"{system_msg}\n\n{user_content}" formatted += f"[INST] {user_content} [/INST]" if ( i + 1 < len(conversation) and conversation[i + 1]["role"] == "assistant" ): formatted += f" {conversation[i + 1]['content']}" i += 2 else: formatted += " " break else: i += 1 return formatted def _format_chatml_template(self, messages: list, special_tokens: dict) -> str: """Format messages using ChatML template""" formatted = "" for msg in messages: role = msg["role"] content = msg["content"] formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n" formatted += "<|im_start|>assistant\n" return formatted def _format_alpaca_template(self, messages: list, special_tokens: dict) -> str: """Format messages using Alpaca template""" formatted = "" system_msg = None for msg in messages: if msg["role"] == "system": system_msg = msg["content"] elif msg["role"] == "user": if system_msg: formatted += f"### Instruction:\n{system_msg}\n\n### Input:\n{msg['content']}\n\n### Response:\n" system_msg = None else: formatted += f"### Human:\n{msg['content']}\n\n### Assistant:\n" elif msg["role"] == "assistant": formatted += f"{msg['content']}\n\n" return formatted def _format_generic_template(self, messages: list, special_tokens: dict) -> str: """Generic fallback formatting""" formatted = "" for msg in messages: 
role = msg["role"].title() content = msg["content"] formatted += f"{role}: {content}\n" formatted += "Assistant: " return formatted def check_vision_model_compatibility(self) -> bool: """ Check if current model supports vision. Returns: bool: True if current model supports vision, False otherwise """ current_model = self.get_current_model() if current_model and current_model in self.models: return self.models[current_model].get("is_vision", False) return False def _reset_model_generation_state(self, model_name: str): """Reset generation state for a specific model to prevent contamination.""" if model_name not in self.models: return model = self.models[model_name].get("model") if not model: return try: # This is a common pattern for Unsloth/Hugging Face models if hasattr(model, "past_key_values"): model.past_key_values = None if hasattr(model, "generation_config"): if hasattr(model.generation_config, "past_key_values"): model.generation_config.past_key_values = None logger.debug(f"Reset generation state for model: {model_name}") except Exception as e: logger.warning(f"Could not fully reset model state for {model_name}: {e}") def reset_generation_state(self): """Reset any cached generation state to prevent hanging after errors""" try: # Clear cached states for ALL loaded models for model_name in self.models.keys(): self._reset_model_generation_state(model_name) clear_gpu_cache() logger.debug("Cleared GPU cache") import gc gc.collect() logger.info("Performed comprehensive generation state reset") except Exception as e: logger.warning(f"Could not fully reset generation state: {e}") def resize_image(self, img, max_size: int = 800): """Resize image while maintaining aspect ratio if either dimension exceeds max_size""" if img is None: return None if img.size[0] > max_size or img.size[1] > max_size: from PIL import Image ratio = min(max_size / img.size[0], max_size / img.size[1]) new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio)) return img.resize(new_size, 
Image.Resampling.LANCZOS) return img def _clean_generated_text(self, text: str) -> str: """Strip leaked special tokens using the tokenizer's own token list.""" if self._is_gpt_oss_model(): # HarmonyTextStreamer produces clean ... output. # Strip harmony protocol tokens and other gpt-oss added tokens # (e.g. <|return|>) that may leak past the streamer. import re text = re.sub(r"<\|[a-z_]+\|>", "", text) return text.strip() tokenizer = self.models.get(self.active_model_name, {}).get("tokenizer") if tokenizer: for token in getattr(tokenizer, "all_special_tokens", []): if token in text: text = text.replace(token, "") return text.strip() def _load_chat_template_info(self, model_name: str): if model_name not in self.models or not self.models[model_name].get( "tokenizer" ): return tokenizer = self.models[model_name]["tokenizer"] chat_template_info = { "has_template": False, "template": None, "format_type": "generic", "special_tokens": {}, "template_name": None, } try: from utils.datasets import MODEL_TO_TEMPLATE_MAPPER # Try exact match first model_name_lower = model_name.lower() if model_name_lower in MODEL_TO_TEMPLATE_MAPPER: chat_template_info["template_name"] = MODEL_TO_TEMPLATE_MAPPER[ model_name_lower ] logger.info( f"Detected template '{chat_template_info['template_name']}' for {model_name} from mapper" ) else: # Try partial match (for variants like model_name-bnb-4bit) for key in MODEL_TO_TEMPLATE_MAPPER: if key in model_name_lower or model_name_lower in key: chat_template_info["template_name"] = MODEL_TO_TEMPLATE_MAPPER[ key ] logger.info( f"Detected template '{chat_template_info['template_name']}' for {model_name} (partial match)" ) break except Exception as e: logger.warning( f"Could not detect template from mapper for {model_name}: {e}" ) try: if hasattr(tokenizer, "chat_template") and tokenizer.chat_template: chat_template_info["has_template"] = True chat_template_info["template"] = tokenizer.chat_template template_str = tokenizer.chat_template.lower() if ( 
"start_header_id" in template_str and "end_header_id" in template_str ): chat_template_info["format_type"] = "llama3" elif "[inst]" in template_str and "[/inst]" in template_str: chat_template_info["format_type"] = "mistral" elif "<|im_start|>" in template_str and "<|im_end|>" in template_str: chat_template_info["format_type"] = "chatml" elif "### instruction:" in template_str or "### human:" in template_str: chat_template_info["format_type"] = "alpaca" else: chat_template_info["format_type"] = "custom" logger.info( f"Loaded chat template for {model_name} (detected as {chat_template_info['format_type']} format)" ) logger.debug(f"Template preview: {tokenizer.chat_template[:200]}...") special_tokens = {} if hasattr(tokenizer, "bos_token") and tokenizer.bos_token: special_tokens["bos_token"] = tokenizer.bos_token if hasattr(tokenizer, "eos_token") and tokenizer.eos_token: special_tokens["eos_token"] = tokenizer.eos_token if hasattr(tokenizer, "pad_token") and tokenizer.pad_token: special_tokens["pad_token"] = tokenizer.pad_token chat_template_info["special_tokens"] = special_tokens else: logger.info( f"No chat template found for {model_name}, will use generic formatting" ) except Exception as e: logger.error(f"Error loading chat template info for {model_name}: {e}") self.models[model_name]["chat_template_info"] = chat_template_info if chat_template_info["has_template"]: logger.info( f"Chat template loaded for {model_name}: {chat_template_info['format_type']} format" ) else: logger.info( f"No built-in chat template for {model_name}, will use generic formatting" ) def get_current_model(self) -> Optional[str]: """Get currently active model name""" return self.active_model_name def is_model_loading(self) -> bool: """Check if any model is currently loading""" return len(self.loading_models) > 0 def get_loading_model(self) -> Optional[str]: """Get name of currently loading model""" return next(iter(self.loading_models)) if self.loading_models else None def 
load_model_simple( self, model_path: str, hf_token: Optional[str] = None, max_seq_length: int = 2048, load_in_4bit: bool = True, ) -> bool: """ Simple model loading wrapper for chat interface. Accepts model path as string and handles ModelConfig creation internally. Args: model_path: Model name or path (e.g., "unsloth/llama-3-8b") hf_token: HuggingFace token for gated models max_seq_length: Maximum sequence length load_in_4bit: Whether to use 4-bit quantization Returns: bool: True if successful, False otherwise """ try: # Create config from string path config = ModelConfig.from_ui_selection( model_path, lora_path = None, # No LoRA for chat is_lora = False, ) # Call existing load_model with config return self.load_model( config = config, max_seq_length = max_seq_length, dtype = None, # Auto-detect load_in_4bit = load_in_4bit, hf_token = hf_token, ) except Exception as e: logger.error(f"Error in load_model_simple: {e}") return False # Global inference backend instance inference_backend = InferenceBackend() def get_inference_backend() -> InferenceBackend: return inference_backend ================================================ FILE: studio/backend/core/inference/llama_cpp.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ llama-server inference backend for GGUF models. Manages a llama-server subprocess and proxies chat completions through its OpenAI-compatible /v1/chat/completions endpoint. """ import atexit import contextlib import json import struct import structlog from loggers import get_logger import shutil import signal import socket import subprocess import threading import time from pathlib import Path from typing import Generator, Optional import httpx logger = get_logger(__name__) class LlamaCppBackend: """ Manages a llama-server subprocess for GGUF model inference. Lifecycle: 1. 
load_model() — starts llama-server with the GGUF file 2. generate_chat_completion() — proxies to /v1/chat/completions, streams back 3. unload_model() — terminates llama-server subprocess """ def __init__(self): self._process: Optional[subprocess.Popen] = None self._port: Optional[int] = None self._model_identifier: Optional[str] = None self._gguf_path: Optional[str] = None self._hf_repo: Optional[str] = None self._hf_variant: Optional[str] = None self._is_vision: bool = False self._healthy = False self._context_length: Optional[int] = None self._chat_template: Optional[str] = None self._supports_reasoning: bool = False self._supports_tools: bool = False self._cache_type_kv: Optional[str] = None self._reasoning_default: bool = True self._lock = threading.Lock() self._stdout_lines: list[str] = [] self._stdout_thread: Optional[threading.Thread] = None self._cancel_event = threading.Event() self._kill_orphaned_servers() atexit.register(self._cleanup) # ── Properties ──────────────────────────────────────────────── @property def is_loaded(self) -> bool: return self._process is not None and self._healthy @property def is_active(self) -> bool: """True if a llama-server process exists (loading or loaded).""" return self._process is not None @property def base_url(self) -> str: return f"http://127.0.0.1:{self._port}" @property def model_identifier(self) -> Optional[str]: return self._model_identifier @property def is_vision(self) -> bool: return self._is_vision @property def hf_variant(self) -> Optional[str]: return self._hf_variant @property def context_length(self) -> Optional[int]: return self._context_length @property def chat_template(self) -> Optional[str]: return self._chat_template @property def supports_reasoning(self) -> bool: return self._supports_reasoning @property def reasoning_default(self) -> bool: return self._reasoning_default @property def supports_tools(self) -> bool: return self._supports_tools @property def cache_type_kv(self) -> Optional[str]: return 
self._cache_type_kv # ── Binary discovery ────────────────────────────────────────── @staticmethod def _find_llama_server_binary() -> Optional[str]: """ Locate the llama-server binary. Search order: 1. LLAMA_SERVER_PATH environment variable (direct path to binary) 1b. UNSLOTH_LLAMA_CPP_PATH env var (custom llama.cpp install dir) 2. ~/.unsloth/llama.cpp/llama-server (make build, root dir) 3. ~/.unsloth/llama.cpp/build/bin/llama-server (cmake build, Linux) 4. ~/.unsloth/llama.cpp/build/bin/Release/llama-server.exe (cmake build, Windows) 5. ./llama.cpp/llama-server (legacy: make build, root dir) 6. ./llama.cpp/build/bin/llama-server (legacy: cmake in-tree build) 7. llama-server on PATH (system install) 8. ./bin/llama-server (legacy: extracted binary) """ import os import sys binary_name = "llama-server.exe" if sys.platform == "win32" else "llama-server" # 1. Env var — direct path to binary env_path = os.environ.get("LLAMA_SERVER_PATH") if env_path and Path(env_path).is_file(): return env_path # 1b. UNSLOTH_LLAMA_CPP_PATH — custom llama.cpp install directory custom_llama_cpp = os.environ.get("UNSLOTH_LLAMA_CPP_PATH") if custom_llama_cpp: custom_dir = Path(custom_llama_cpp) # Root dir (make builds) root_bin = custom_dir / binary_name if root_bin.is_file(): return str(root_bin) # build/bin/ (cmake builds on Linux) cmake_bin = custom_dir / "build" / "bin" / binary_name if cmake_bin.is_file(): return str(cmake_bin) # build/bin/Release/ (cmake builds on Windows) if sys.platform == "win32": win_bin = custom_dir / "build" / "bin" / "Release" / binary_name if win_bin.is_file(): return str(win_bin) # 2–4. 
~/.unsloth/llama.cpp (primary — setup.sh / setup.ps1 build here) unsloth_home = Path.home() / ".unsloth" / "llama.cpp" # Root dir (make builds copy binaries here) home_root = unsloth_home / binary_name if home_root.is_file(): return str(home_root) # build/bin/ (cmake builds on Linux) home_linux = unsloth_home / "build" / "bin" / binary_name if home_linux.is_file(): return str(home_linux) # 3. Windows MSVC build has Release subdir if sys.platform == "win32": home_win = unsloth_home / "build" / "bin" / "Release" / binary_name if home_win.is_file(): return str(home_win) # 5–6. Legacy: in-tree build (older setup.sh / setup.ps1 versions) project_root = Path(__file__).resolve().parents[4] # Root dir (make builds) root_path = project_root / "llama.cpp" / binary_name if root_path.is_file(): return str(root_path) # build/bin/ (cmake builds) build_path = project_root / "llama.cpp" / "build" / "bin" / binary_name if build_path.is_file(): return str(build_path) if sys.platform == "win32": win_path = ( project_root / "llama.cpp" / "build" / "bin" / "Release" / binary_name ) if win_path.is_file(): return str(win_path) # 7. System PATH system_path = shutil.which("llama-server") if system_path: return system_path # 8. 
Legacy: extracted to bin/ bin_path = project_root / "bin" / binary_name if bin_path.is_file(): return str(bin_path) return None # ── GPU allocation ──────────────────────────────────────────── @staticmethod def _get_gguf_size_bytes(model_path: str) -> int: """Get total GGUF size in bytes, including split shards.""" import re main = Path(model_path) total = main.stat().st_size # Check for split shards (e.g., model-00001-of-00003.gguf) shard_pat = re.compile(r"^(.*)-(\d{5})-of-(\d{5})\.gguf$") m = shard_pat.match(main.name) if m: prefix, _, num_total = m.group(1), m.group(2), m.group(3) sibling_pat = re.compile( r"^" + re.escape(prefix) + r"-\d{5}-of-" + re.escape(num_total) + r"\.gguf$" ) for sibling in main.parent.iterdir(): if sibling != main and sibling_pat.match(sibling.name): total += sibling.stat().st_size return total @staticmethod def _get_gpu_free_memory() -> list[tuple[int, int]]: """Query free memory per GPU via nvidia-smi. Returns list of (gpu_index, free_mib) sorted by index. Respects CUDA_VISIBLE_DEVICES if set. Returns empty list if nvidia-smi is not available. 
""" import os try: result = subprocess.run( [ "nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader,nounits", ], capture_output = True, text = True, timeout = 10, ) if result.returncode != 0: return [] # Parse which GPUs are allowed by existing CUDA_VISIBLE_DEVICES allowed = None cvd = os.environ.get("CUDA_VISIBLE_DEVICES") if cvd is not None and cvd.strip(): try: allowed = set(int(x.strip()) for x in cvd.split(",")) except ValueError: pass # Non-numeric (e.g., "GPU-uuid"), ignore filter gpus = [] for line in result.stdout.strip().splitlines(): parts = line.split(",") if len(parts) == 2: idx = int(parts[0].strip()) free_mib = int(parts[1].strip()) if allowed is not None and idx not in allowed: continue gpus.append((idx, free_mib)) return gpus except Exception: return [] @staticmethod def _select_gpus( model_size_bytes: int, gpus: list[tuple[int, int]], ) -> tuple[Optional[list[int]], bool]: """Pick GPU(s) for a model based on file size and free memory. Uses GGUF file size as a rough proxy for VRAM usage (actual usage is higher due to KV cache and compute buffers, but 70% threshold accounts for that). 
Returns (gpu_indices, use_fit): - ([1], False) model fits on 1 GPU at 70% of free - ([1, 2], False) model needs 2 GPUs - (None, True) model too large, let --fit handle it """ if not gpus: return None, True model_size_mib = model_size_bytes / (1024 * 1024) # Sort GPUs by free memory descending ranked = sorted(gpus, key = lambda g: g[1], reverse = True) # Try fitting on 1 GPU (70% of free memory threshold) if ranked[0][1] * 0.70 >= model_size_mib: return [ranked[0][0]], False # Try fitting on N GPUs (accumulate free memory from most-free) cumulative = 0 selected = [] for idx, free_mib in ranked: selected.append(idx) cumulative += free_mib * 0.70 if cumulative >= model_size_mib: return sorted(selected), False # Model is too large even for all GPUs, let --fit handle it return None, True # ── Variant fallback ──────────────────────────────────────────── @staticmethod def _find_smallest_fitting_variant( hf_repo: str, free_bytes: int, hf_token: Optional[str] = None, ) -> Optional[tuple[str, int]]: """Find the smallest GGUF variant (including all shards) that fits. Groups split shards by variant prefix and sums their sizes. For example, UD-Q4_K_XL with 9 shards of 50 GB each = 450 GB total. Returns (first_shard_filename, total_size_bytes) or None if nothing fits. 
""" import re try: from huggingface_hub import get_paths_info, list_repo_files files = list_repo_files(hf_repo, token = hf_token) gguf_files = [ f for f in files if f.endswith(".gguf") and "mmproj" not in f.lower() ] if not gguf_files: return None # Get sizes for all GGUF files path_infos = list(get_paths_info(hf_repo, gguf_files, token = hf_token)) size_map = {p.path: (p.size or 0) for p in path_infos} # Group files by variant: shards share a prefix before -NNNNN-of-NNNNN shard_pat = re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$") variants: dict[str, list[str]] = {} for f in gguf_files: m = shard_pat.match(f) key = m.group(1) if m else f variants.setdefault(key, []).append(f) # Sum shard sizes per variant, track the first shard (for download) variant_sizes: list[tuple[str, int, list[str]]] = [] for key, shard_files in variants.items(): total = sum(size_map.get(f, 0) for f in shard_files) first = sorted(shard_files)[0] variant_sizes.append((first, total, shard_files)) # Sort by total size ascending and pick the smallest that fits variant_sizes.sort(key = lambda x: x[1]) for first_file, total_size, _ in variant_sizes: if total_size > 0 and total_size <= free_bytes: return first_file, total_size return None except Exception: return None # ── Port allocation ─────────────────────────────────────────── @staticmethod def _find_free_port() -> int: """Find an available TCP port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("127.0.0.1", 0)) return s.getsockname()[1] # ── Stdout drain (prevents pipe deadlock on Windows) ───────── def _drain_stdout(self): """ Read lines from the subprocess stdout in a background thread. This prevents a pipe-buffer deadlock on Windows where the default pipe buffer is only ~4 KB. Without draining, llama-server blocks on writes and never becomes healthy. 
""" try: for line in self._process.stdout: line = line.rstrip() if line: self._stdout_lines.append(line) logger.debug(f"[llama-server] {line}") except (ValueError, OSError): # Pipe closed — process is terminating pass # GGUF KV type sizes for fast skipping _GGUF_TYPE_SIZE = { 0: 1, 1: 1, 2: 2, 3: 2, 4: 4, 5: 4, 6: 4, 7: 1, 10: 8, 11: 8, 12: 8, } @staticmethod def _gguf_skip_value(f, vtype: int) -> None: """Skip a GGUF KV value without reading it.""" sz = LlamaCppBackend._GGUF_TYPE_SIZE.get(vtype) if sz is not None: f.seek(sz, 1) elif vtype == 8: # STRING slen = struct.unpack(" None: """Read context_length and chat_template from a GGUF file's KV header. Parses only the KV pairs we need (~30ms even for multi-GB files). For split GGUFs, metadata is always in shard 1. """ # Reset metadata from any previously loaded model so stale flags # (eg _supports_reasoning) do not carry over when switching models. self._context_length = None self._chat_template = None self._supports_reasoning = False self._supports_tools = False try: WANTED = {"general.architecture", "tokenizer.chat_template"} arch = None ctx_key = None with open(gguf_path, "rb") as f: magic = struct.unpack(" str: """Download GGUF file(s) from HuggingFace. Returns local path. Runs WITHOUT self._lock so that unload_model() can set _cancel_event at any time. Checks _cancel_event between each shard download. """ try: from huggingface_hub import hf_hub_download except ImportError: raise RuntimeError( "huggingface_hub is required for HF model loading. " "Install it with: pip install huggingface_hub" ) # Determine the filename from the variant gguf_filename = None gguf_extra_shards: list[str] = [] if hf_variant: try: import re from huggingface_hub import list_repo_files files = list_repo_files(hf_repo, token = hf_token) variant_lower = hf_variant.lower() boundary = re.compile( r"(? 
0: cache_dir = os.environ.get( "HF_HUB_CACHE", str(Path.home() / ".cache" / "huggingface" / "hub"), ) Path(cache_dir).mkdir(parents = True, exist_ok = True) free_bytes = shutil.disk_usage(cache_dir).free total_gb = total_download_bytes / (1024**3) free_gb = free_bytes / (1024**3) logger.info( f"GGUF download: {total_gb:.1f} GB needed, " f"{free_gb:.1f} GB free on disk" ) if total_download_bytes > free_bytes: smaller = self._find_smallest_fitting_variant( hf_repo, free_bytes, hf_token, ) if smaller: fallback_file, fallback_size = smaller logger.info( f"Selected variant too large ({total_gb:.1f} GB), " f"falling back to {fallback_file} ({fallback_size / (1024**3):.1f} GB)" ) gguf_filename = fallback_file import re as _re _shard_pat = _re.compile(r"^(.*)-\d{5}-of-\d{5}\.gguf$") _m = _shard_pat.match(gguf_filename) _prefix = _m.group(1) if _m else None if _prefix: gguf_extra_shards = sorted( f for f in all_gguf_files if f.startswith(_prefix) and f != gguf_filename and "mmproj" not in f.lower() ) else: gguf_extra_shards = [] else: raise RuntimeError( f"Not enough disk space to download any variant. 
" f"Only {free_gb:.1f} GB free in {cache_dir}" ) except RuntimeError: raise except Exception as e: logger.warning(f"Could not check disk space: {e}") gguf_label = f"{hf_repo}/{gguf_filename}" + ( f" (+{len(gguf_extra_shards)} shards)" if gguf_extra_shards else "" ) logger.info(f"Resolving GGUF: {gguf_label}") try: if self._cancel_event.is_set(): raise RuntimeError("Cancelled") dl_start = time.monotonic() local_path = hf_hub_download( repo_id = hf_repo, filename = gguf_filename, token = hf_token, ) for shard in gguf_extra_shards: if self._cancel_event.is_set(): raise RuntimeError("Cancelled") logger.info(f"Resolving GGUF shard: {shard}") hf_hub_download( repo_id = hf_repo, filename = shard, token = hf_token, ) except RuntimeError as e: if "Cancelled" in str(e): raise raise RuntimeError( f"Failed to download GGUF file '{gguf_filename}' from {hf_repo}: {e}" ) except Exception as e: raise RuntimeError( f"Failed to download GGUF file '{gguf_filename}' from {hf_repo}: {e}" ) dl_elapsed = time.monotonic() - dl_start if dl_elapsed < 2.0: logger.info(f"GGUF resolved from cache: {local_path}") else: logger.info(f"GGUF downloaded in {dl_elapsed:.1f}s: {local_path}") return local_path def _download_mmproj( self, *, hf_repo: str, hf_token: Optional[str] = None, ) -> Optional[str]: """Download the mmproj (vision projection) file from a GGUF repo. Prefers mmproj-F16.gguf, falls back to any mmproj*.gguf file. Returns the local path, or None if no mmproj file exists. 
""" try: from huggingface_hub import hf_hub_download, list_repo_files files = list_repo_files(hf_repo, token = hf_token) mmproj_files = sorted( f for f in files if f.endswith(".gguf") and "mmproj" in f.lower() ) if not mmproj_files: return None # Prefer F16 variant target = None for f in mmproj_files: if "f16" in f.lower(): target = f break if target is None: target = mmproj_files[0] logger.info(f"Downloading mmproj: {hf_repo}/{target}") local_path = hf_hub_download( repo_id = hf_repo, filename = target, token = hf_token, ) return local_path except Exception as e: logger.warning(f"Could not download mmproj: {e}") return None # ── Lifecycle ───────────────────────────────────────────────── def load_model( self, *, # Local mode: pass a path to a .gguf file gguf_path: Optional[str] = None, # Vision projection (mmproj) for local vision models mmproj_path: Optional[str] = None, # HF mode: let llama-server download via -hf "repo:quant" hf_repo: Optional[str] = None, hf_variant: Optional[str] = None, hf_token: Optional[str] = None, # Common model_identifier: str, is_vision: bool = False, n_ctx: int = 4096, chat_template_override: Optional[str] = None, cache_type_kv: Optional[str] = None, n_threads: Optional[int] = None, n_gpu_layers: Optional[int] = None, # Accepted for caller compat, unused ) -> bool: """ Start llama-server with a GGUF model. Two modes: - Local: ``gguf_path="/path/to/model.gguf"`` → uses ``-m`` - HF: ``hf_repo="unsloth/gemma-3-4b-it-GGUF", hf_variant="Q4_K_M"`` → uses ``-hf`` In HF mode, llama-server handles downloading, caching, and auto-loading mmproj files for vision models. Returns True if server started and health check passed. """ self._cancel_event.clear() # ── Phase 1: kill old process (under lock, fast) ────────── with self._lock: self._kill_process() binary = self._find_llama_server_binary() if not binary: raise RuntimeError( "llama-server binary not found. 
" "Run setup.sh to build it, install llama.cpp, " "or set LLAMA_SERVER_PATH environment variable." ) # ── Phase 2: download (NO lock held, so cancel can proceed) ── if hf_repo: model_path = self._download_gguf( hf_repo = hf_repo, hf_variant = hf_variant, hf_token = hf_token, ) # Auto-download mmproj for vision models if is_vision and not mmproj_path: mmproj_path = self._download_mmproj( hf_repo = hf_repo, hf_token = hf_token, ) elif gguf_path: if not Path(gguf_path).is_file(): raise FileNotFoundError(f"GGUF file not found: {gguf_path}") model_path = gguf_path else: raise ValueError("Either gguf_path or hf_repo must be provided") # Set identifier early so _read_gguf_metadata can use it for DeepSeek detection self._model_identifier = model_identifier # Read GGUF metadata (context_length, chat_template) -- fast, header only self._read_gguf_metadata(model_path) # Check cancel after download if self._cancel_event.is_set(): logger.info("Load cancelled after download phase") return False # ── Phase 3: start llama-server (under lock) ────────────── with self._lock: # Re-check cancel inside lock if self._cancel_event.is_set(): logger.info("Load cancelled before server start") return False self._port = self._find_free_port() # Select GPU(s) based on model size and free memory try: model_size = self._get_gguf_size_bytes(model_path) gpus = self._get_gpu_free_memory() gpu_indices, use_fit = self._select_gpus(model_size, gpus) logger.info( f"GGUF size: {model_size / (1024**3):.1f} GB, " f"GPUs free: {gpus}, selected: {gpu_indices}, fit: {use_fit}" ) except Exception as e: logger.warning(f"GPU selection failed ({e}), using --fit on") gpu_indices, use_fit = None, True cmd = [ binary, "-m", model_path, "--port", str(self._port), "-c", "0", # 0 = use model's native context size "--parallel", "1", # Single-user studio, saves VRAM "--flash-attn", "on", # Force flash attention for speed ] if use_fit: cmd.extend(["--fit", "on"]) if n_threads is not None: cmd.extend(["--threads", 
str(n_threads)]) # Always enable Jinja chat template rendering for proper template support cmd.extend(["--jinja"]) # KV cache data type _valid_cache_types = { "f16", "bf16", "q8_0", "q4_0", "q4_1", "q5_0", "q5_1", "iq4_nl", "f32", } if cache_type_kv and cache_type_kv in _valid_cache_types: cmd.extend( ["--cache-type-k", cache_type_kv, "--cache-type-v", cache_type_kv] ) self._cache_type_kv = cache_type_kv logger.info(f"KV cache type: {cache_type_kv}") else: self._cache_type_kv = None # Apply custom chat template override if provided if chat_template_override: import tempfile self._chat_template_file = tempfile.NamedTemporaryFile( mode = "w", suffix = ".jinja", delete = False, prefix = "unsloth_chat_template_", ) self._chat_template_file.write(chat_template_override) self._chat_template_file.close() cmd.extend(["--chat-template-file", self._chat_template_file.name]) logger.info( f"Using custom chat template file: {self._chat_template_file.name}" ) # For reasoning models, set default thinking mode. # Qwen3.5 models below 9B (0.8B, 2B, 4B) disable thinking by default. # Only 9B and larger enable thinking. if self._supports_reasoning: import re thinking_default = True mid = (model_identifier or "").lower() if "qwen3.5" in mid: # Extract size like "0.8b", "4b", "35b" etc. 
size_match = re.search(r"(\d+\.?\d*)\s*b", mid) if size_match: size_val = float(size_match.group(1)) if size_val < 9: thinking_default = False self._reasoning_default = thinking_default cmd.extend( [ "--chat-template-kwargs", json.dumps({"enable_thinking": thinking_default}), ] ) logger.info( f"Reasoning model: enable_thinking={thinking_default} by default" ) if mmproj_path: if not Path(mmproj_path).is_file(): logger.warning(f"mmproj file not found: {mmproj_path}") else: cmd.extend(["--mmproj", mmproj_path]) logger.info(f"Using mmproj for vision: {mmproj_path}") logger.info(f"Starting llama-server: {' '.join(cmd)}") # Set library paths so llama-server can find its shared libs and CUDA DLLs import os import sys env = os.environ.copy() binary_dir = str(Path(binary).parent) if sys.platform == "win32": # On Windows, CUDA DLLs (cublas64_12.dll, cudart64_12.dll, etc.) # must be on PATH. Add CUDA_PATH\bin if available. path_dirs = [binary_dir] cuda_path = os.environ.get("CUDA_PATH", "") if cuda_path: cuda_bin = os.path.join(cuda_path, "bin") if os.path.isdir(cuda_bin): path_dirs.append(cuda_bin) # Some CUDA installs put DLLs in bin\x64 cuda_bin_x64 = os.path.join(cuda_path, "bin", "x64") if os.path.isdir(cuda_bin_x64): path_dirs.append(cuda_bin_x64) existing_path = env.get("PATH", "") env["PATH"] = ";".join(path_dirs) + ";" + existing_path else: # Linux: set LD_LIBRARY_PATH for shared libs next to the binary # and CUDA runtime libs (libcudart, libcublas, etc.) import platform lib_dirs = [binary_dir] _arch = platform.machine() # x86_64, aarch64, etc. for cuda_lib in [ "/usr/local/cuda/lib64", f"/usr/local/cuda/targets/{_arch}-linux/lib", # Fallback CUDA compat paths (e.g. binary built with # CUDA 12 on a system where default /usr/local/cuda # points to CUDA 13+). 
"/usr/local/cuda-12/lib64", "/usr/local/cuda-12.8/lib64", f"/usr/local/cuda-12/targets/{_arch}-linux/lib", f"/usr/local/cuda-12.8/targets/{_arch}-linux/lib", ]: if os.path.isdir(cuda_lib): lib_dirs.append(cuda_lib) existing_ld = env.get("LD_LIBRARY_PATH", "") new_ld = ":".join(lib_dirs) env["LD_LIBRARY_PATH"] = ( f"{new_ld}:{existing_ld}" if existing_ld else new_ld ) # Pin to selected GPU(s) via CUDA_VISIBLE_DEVICES if gpu_indices is not None: env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_indices) self._stdout_lines = [] self._process = subprocess.Popen( cmd, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, env = env, ) # Start background thread to drain stdout and prevent pipe deadlock self._stdout_thread = threading.Thread( target = self._drain_stdout, daemon = True, name = "llama-stdout" ) self._stdout_thread.start() self._gguf_path = gguf_path self._hf_repo = hf_repo self._hf_variant = hf_variant self._is_vision = is_vision self._model_identifier = model_identifier # Wait for llama-server to become healthy if not self._wait_for_health(timeout = 120.0): self._kill_process() raise RuntimeError( "llama-server failed to start. " "Check that the GGUF file is valid and you have enough memory." 
) self._healthy = True logger.info( f"llama-server ready on port {self._port} " f"for model '{model_identifier}'" ) return True def unload_model(self) -> bool: """Terminate the llama-server subprocess and cancel any in-flight download.""" self._cancel_event.set() with self._lock: self._kill_process() logger.info(f"Unloaded GGUF model: {self._model_identifier}") self._model_identifier = None self._gguf_path = None self._hf_repo = None self._hf_variant = None self._is_vision = False self._is_audio = False self._audio_type = None self._port = None self._healthy = False self._context_length = None self._chat_template = None self._supports_reasoning = False self._supports_tools = False self._cache_type_kv = None # Clean up temp chat template file if hasattr(self, "_chat_template_file") and self._chat_template_file: try: import os os.unlink(self._chat_template_file.name) except Exception: pass self._chat_template_file = None # Free audio codec GPU memory if LlamaCppBackend._codec_mgr is not None: LlamaCppBackend._codec_mgr.unload() LlamaCppBackend._codec_mgr = None import torch if torch.cuda.is_available(): torch.cuda.empty_cache() return True def _kill_process(self): """Terminate the subprocess if running.""" if self._process is None: return try: self._process.terminate() self._process.wait(timeout = 5) except subprocess.TimeoutExpired: logger.warning("llama-server did not exit on SIGTERM, sending SIGKILL") self._process.kill() self._process.wait(timeout = 5) except Exception as e: logger.warning(f"Error killing llama-server process: {e}") finally: self._process = None if self._stdout_thread is not None: self._stdout_thread.join(timeout = 2) self._stdout_thread = None @staticmethod def _kill_orphaned_servers(): """Kill orphaned llama-server processes started by studio. Only kills processes whose binary lives under ~/.unsloth/llama.cpp/ to avoid terminating unrelated llama-server instances on the machine. 
""" import os import signal try: # Use pgrep with full command match to identify studio-managed servers result = subprocess.run( ["pgrep", "-a", "-f", "llama-server"], capture_output = True, text = True, timeout = 5, ) if result.returncode != 0: return for line in result.stdout.strip().splitlines(): parts = line.strip().split(None, 1) if len(parts) < 2: continue pid = int(parts[0]) cmdline = parts[1] if pid == os.getpid(): continue # Only kill if it's a studio-managed server (lives under .unsloth/) if ".unsloth/" not in cmdline and "unsloth" not in cmdline.lower(): continue try: os.kill(pid, signal.SIGKILL) logger.info(f"Killed orphaned llama-server process (pid={pid})") except ProcessLookupError: pass except PermissionError: pass except Exception: pass def _cleanup(self): """atexit handler to ensure llama-server is terminated.""" self._kill_process() def _wait_for_health(self, timeout: float = 120.0, interval: float = 0.5) -> bool: """ Poll llama-server's /health endpoint until it responds 200. Also monitors subprocess for early exit/crash. """ deadline = time.monotonic() + timeout url = f"http://127.0.0.1:{self._port}/health" while time.monotonic() < deadline: # Check if process crashed if self._process.poll() is not None: # Give the drain thread a moment to collect final output if self._stdout_thread is not None: self._stdout_thread.join(timeout = 2) output = "\n".join(self._stdout_lines[-50:]) logger.error( f"llama-server exited with code {self._process.returncode}. " f"Output: {output[:2000]}" ) return False try: resp = httpx.get(url, timeout = 2.0) if resp.status_code == 200: return True except (httpx.ConnectError, httpx.TimeoutException): pass time.sleep(interval) logger.error(f"llama-server health check timed out after {timeout}s") return False # ── Message building (OpenAI format) ────────────────────────── @staticmethod def _parse_tool_calls_from_text(content: str) -> list[dict]: """ Parse tool calls from XML markup in content text. 
Handles formats like: {"name":"web_search","arguments":{"query":"..."}} ... Closing tags (, , ) are all optional since models frequently omit them. """ import re tool_calls = [] # Pattern 1: JSON inside tags. # Use balanced-brace extraction that skips braces inside JSON strings. for m in re.finditer(r"\s*\{", content): brace_start = m.end() - 1 # position of the opening { depth, i = 0, brace_start in_string = False while i < len(content): ch = content[i] if in_string: if ch == "\\" and i + 1 < len(content): i += 2 # skip escaped character continue if ch == '"': in_string = False elif ch == '"': in_string = True elif ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: break i += 1 if depth == 0: json_str = content[brace_start : i + 1] try: obj = json.loads(json_str) tc = { "id": f"call_{len(tool_calls)}", "type": "function", "function": { "name": obj.get("name", ""), "arguments": obj.get("arguments", {}), }, } if isinstance(tc["function"]["arguments"], dict): tc["function"]["arguments"] = json.dumps( tc["function"]["arguments"] ) tool_calls.append(tc) except (json.JSONDecodeError, ValueError): pass # Pattern 2: XML-style value # All closing tags optional -- models frequently omit , # , and/or . if not tool_calls: # Step 1: Find all positions and extract their bodies. # Body boundary: use only or next as a boundary because # code parameter values can contain that literal string. # After extracting, we trim a trailing if present. 
func_starts = list(re.finditer(r"\s*", content)) for idx, fm in enumerate(func_starts): func_name = fm.group(1) body_start = fm.end() # Hard boundaries: next next_func = ( func_starts[idx + 1].start() if idx + 1 < len(func_starts) else len(content) ) end_tag = re.search(r"", content[body_start:]) if end_tag: body_end = body_start + end_tag.start() else: body_end = len(content) body_end = min(body_end, next_func) body = content[body_start:body_end] # Trim trailing if present (it's the real closing tag) body = re.sub(r"\s*\s*$", "", body) # Step 2: Extract parameters from body. # For single-parameter functions (the common case: code, command, # query), use body end as the only boundary to avoid false matches # on inside code strings. arguments = {} param_starts = list(re.finditer(r"\s*", body)) if len(param_starts) == 1: # Single parameter: value is everything from after the tag # to end of body, trimming any trailing . pm = param_starts[0] val = body[pm.end() :] val = re.sub(r"\s*\s*$", "", val) arguments[pm.group(1)] = val.strip() else: for pidx, pm in enumerate(param_starts): param_name = pm.group(1) val_start = pm.end() # Value ends at next if present val = re.sub(r"\s*\s*$", "", val) arguments[param_name] = val.strip() tc = { "id": f"call_{len(tool_calls)}", "type": "function", "function": { "name": func_name, "arguments": json.dumps(arguments), }, } tool_calls.append(tc) return tool_calls @staticmethod def _build_openai_messages( messages: list[dict], image_b64: Optional[str] = None, ) -> list[dict]: """ Build OpenAI-format messages, optionally injecting an image_url content part into the last user message for vision models. If no image is provided, returns messages as-is. 
""" if not image_b64: return messages # Find the last user message and convert to multimodal content parts result = [msg.copy() for msg in messages] last_user_idx = None for i, msg in enumerate(result): if msg["role"] == "user": last_user_idx = i if last_user_idx is not None: text_content = result[last_user_idx].get("content", "") result[last_user_idx]["content"] = [ {"type": "text", "text": text_content}, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{image_b64}", }, }, ] return result # ── Generation (proxy to llama-server) ──────────────────────── @staticmethod def _iter_text_cancellable( response: "httpx.Response", cancel_event: Optional[threading.Event] = None, ) -> Generator[str, None, None]: """Iterate over an httpx streaming response with cancel support. Checks cancel_event between chunks and on ReadTimeout. The cancel watcher in _stream_with_retry also calls response.close() on cancel, which unblocks iter_text() once the response exists. During normal streaming llama-server sends tokens frequently, so the cancel check between chunks is the primary mechanism. """ text_iter = response.iter_text() while True: if cancel_event is not None and cancel_event.is_set(): response.close() return try: chunk = next(text_iter) yield chunk except StopIteration: return except httpx.ReadTimeout: # No data within the timeout window -- just loop back # and re-check cancel_event. continue @staticmethod @contextlib.contextmanager def _stream_with_retry( client: "httpx.Client", url: str, payload: dict, cancel_event: Optional[threading.Event] = None, ): """Open an httpx streaming POST with cancel support. Sends the request once with a long read timeout (120 s) so prompt processing (prefill) can finish without triggering a retry storm. The previous 0.5 s timeout caused duplicate POST requests every half second, forcing llama-server to restart processing each time. A background watcher thread provides cancel by closing the response when cancel_event is set. 
Limitation: httpx does not allow interrupting a blocked read from another thread before the response object exists, so cancel during the initial header wait (prefill phase) only takes effect once headers arrive. After that, response.close() unblocks reads promptly. In practice llama-server prefill is 1-5 s for typical prompts, during which cancel is deferred -- still much better than the old retry storm which made prefill slower. """ if cancel_event is not None and cancel_event.is_set(): raise GeneratorExit # Background watcher: close the response if cancel is requested. # Only effective after response headers arrive (httpx limitation). _cancel_closed = threading.Event() _response_ref: list = [None] def _cancel_watcher(): while not _cancel_closed.is_set(): if cancel_event.wait(timeout = 0.3): # Cancel requested. Keep polling until the response object # exists so we can close it, or until the main thread # finishes on its own (_cancel_closed is set in finally). while not _cancel_closed.is_set(): r = _response_ref[0] if r is not None: try: r.close() return except Exception as e: logger.debug( f"Error closing response in cancel watcher: {e}" ) # Response not created yet -- wait briefly and retry _cancel_closed.wait(timeout = 0.1) return watcher = None if cancel_event is not None: watcher = threading.Thread( target = _cancel_watcher, daemon = True, name = "prefill-cancel" ) watcher.start() try: # Long read timeout so prefill (prompt processing) can finish # without triggering a retry storm. Cancel during both # prefill and streaming is handled by the watcher thread # which closes the response, unblocking any httpx read. 
prefill_timeout = httpx.Timeout( connect = 30, read = 120.0, write = 10, pool = 10, ) with client.stream( "POST", url, json = payload, timeout = prefill_timeout ) as response: _response_ref[0] = response if cancel_event is not None and cancel_event.is_set(): raise GeneratorExit yield response return except (httpx.ReadError, httpx.RemoteProtocolError, httpx.CloseError): # Response was closed by the cancel watcher if cancel_event is not None and cancel_event.is_set(): raise GeneratorExit raise finally: _cancel_closed.set() def generate_chat_completion( self, messages: list[dict], image_b64: Optional[str] = None, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 20, min_p: float = 0.01, max_tokens: Optional[int] = None, repetition_penalty: float = 1.0, presence_penalty: float = 0.0, stop: Optional[list[str]] = None, cancel_event: Optional[threading.Event] = None, enable_thinking: Optional[bool] = None, ) -> Generator[str, None, None]: """ Send a chat completion request to llama-server and stream tokens back. Uses /v1/chat/completions — llama-server handles chat template application and vision (multimodal image_url parts) natively. Yields cumulative text (matching InferenceBackend's convention). 
""" if not self.is_loaded: raise RuntimeError("llama-server is not loaded") openai_messages = self._build_openai_messages(messages, image_b64) payload = { "messages": openai_messages, "stream": True, "temperature": temperature, "top_p": top_p, "top_k": top_k if top_k >= 0 else 0, "min_p": min_p, "repeat_penalty": repetition_penalty, "presence_penalty": presence_penalty, } # Pass enable_thinking per-request for reasoning models if self._supports_reasoning and enable_thinking is not None: payload["chat_template_kwargs"] = {"enable_thinking": enable_thinking} if max_tokens is not None: payload["max_tokens"] = max_tokens if stop: payload["stop"] = stop url = f"{self.base_url}/v1/chat/completions" cumulative = "" in_thinking = False try: # _stream_with_retry uses a 120 s read timeout so prefill # can finish. Cancel during streaming is handled by the # watcher thread (closes the response on cancel_event). stream_timeout = httpx.Timeout(connect = 10, read = 0.5, write = 10, pool = 10) with httpx.Client(timeout = stream_timeout) as client: with self._stream_with_retry( client, url, payload, cancel_event ) as response: if response.status_code != 200: error_body = response.read().decode() raise RuntimeError( f"llama-server returned {response.status_code}: {error_body}" ) buffer = "" has_content_tokens = False reasoning_text = "" for raw_chunk in self._iter_text_cancellable( response, cancel_event ): buffer += raw_chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = line.strip() if not line: continue if line == "data: [DONE]": if in_thinking: if has_content_tokens: # Real thinking + content: close the tag cumulative += "" yield cumulative else: # Only reasoning_content, no content tokens: # the model put its entire reply in reasoning # (e.g. Qwen3 always-think mode). Show it # as the main response, not as a thinking block. 
cumulative = reasoning_text yield cumulative return if not line.startswith("data: "): continue try: data = json.loads(line[6:]) choices = data.get("choices", []) if choices: delta = choices[0].get("delta", {}) # Handle reasoning/thinking tokens # llama-server sends these as "reasoning_content" # Wrap in tags for the frontend parser reasoning = delta.get("reasoning_content", "") if reasoning: reasoning_text += reasoning if not in_thinking: cumulative += "" in_thinking = True cumulative += reasoning yield cumulative token = delta.get("content", "") if token: has_content_tokens = True if in_thinking: cumulative += "" in_thinking = False cumulative += token yield cumulative except json.JSONDecodeError: logger.debug( f"Skipping malformed SSE line: {line[:100]}" ) except httpx.ConnectError: raise RuntimeError("Lost connection to llama-server") except Exception as e: if cancel_event is not None and cancel_event.is_set(): return raise # ── Tool-calling agentic loop ────────────────────────────── def generate_chat_completion_with_tools( self, messages: list[dict], tools: list[dict], temperature: float = 0.6, top_p: float = 0.95, top_k: int = 20, min_p: float = 0.01, max_tokens: Optional[int] = None, repetition_penalty: float = 1.0, presence_penalty: float = 0.0, stop: Optional[list[str]] = None, cancel_event: Optional[threading.Event] = None, enable_thinking: Optional[bool] = None, max_tool_iterations: int = 10, auto_heal_tool_calls: bool = True, tool_call_timeout: int = 300, session_id: Optional[str] = None, ) -> Generator[dict, None, None]: """ Agentic loop: let the model call tools, execute them, and continue. 
Yields dicts with: {"type": "status", "text": "Searching: ..."} -- tool status updates {"type": "content", "text": "token"} -- streamed content tokens (cumulative) {"type": "reasoning", "text": "token"} -- streamed reasoning tokens (cumulative) """ from core.inference.tools import execute_tool if not self.is_loaded: raise RuntimeError("llama-server is not loaded") conversation = list(messages) url = f"{self.base_url}/v1/chat/completions" for iteration in range(max_tool_iterations): if cancel_event is not None and cancel_event.is_set(): return # Build payload for non-streaming tool detection pass payload = { "messages": conversation, "stream": False, "temperature": temperature, "top_p": top_p, "top_k": top_k if top_k >= 0 else 0, "min_p": min_p, "repeat_penalty": repetition_penalty, "presence_penalty": presence_penalty, "tools": tools, "tool_choice": "auto", } if self._supports_reasoning and enable_thinking is not None: payload["chat_template_kwargs"] = {"enable_thinking": enable_thinking} if max_tokens is not None: payload["max_tokens"] = max_tokens if stop: payload["stop"] = stop try: with httpx.Client(timeout = None) as client: resp = client.post(url, json = payload) if resp.status_code != 200: raise RuntimeError( f"llama-server returned {resp.status_code}: {resp.text}" ) data = resp.json() except httpx.ConnectError: raise RuntimeError("Lost connection to llama-server") choices = data.get("choices", []) if not choices: return choice = choices[0] finish_reason = choice.get("finish_reason", "") message = choice.get("message", {}) # If model wants to call tools tool_calls = message.get("tool_calls") # Fallback: detect tool calls embedded as XML/text in content # Some models output XML instead of structured tool_calls, # or bare tags without wrapper. content_text = message.get("content", "") or "" if ( auto_heal_tool_calls and not tool_calls and ("" in content_text or " blocks since they # can contain arbitrary content including code. import re # Strip ... 
blocks (greedy inside) content_text = re.sub( r".*?", "", content_text, flags = re.DOTALL, ) # Strip unterminated ... to end content_text = re.sub( r".*$", "", content_text, flags = re.DOTALL, ) # Strip bare ... blocks content_text = re.sub( r".*?", "", content_text, flags = re.DOTALL, ) # Strip unterminated bare to end content_text = re.sub( r".*$", "", content_text, flags = re.DOTALL, ).strip() logger.info( f"Parsed {len(tool_calls)} tool call(s) from content text" ) if finish_reason == "tool_calls" or (tool_calls and len(tool_calls) > 0): # Append the assistant message with tool_calls to conversation assistant_msg = {"role": "assistant", "content": content_text} if tool_calls: assistant_msg["tool_calls"] = tool_calls conversation.append(assistant_msg) # Execute each tool call for tc in tool_calls or []: func = tc.get("function", {}) tool_name = func.get("name", "") raw_args = func.get("arguments", {}) # Handle arguments as either string or dict if isinstance(raw_args, str): try: arguments = json.loads(raw_args) except (json.JSONDecodeError, ValueError): if auto_heal_tool_calls: arguments = {"query": raw_args} else: arguments = {"raw": raw_args} else: arguments = raw_args # Yield status update if tool_name == "web_search": status_text = f"Searching: {arguments.get('query', '')}" elif tool_name == "python": preview = ( (arguments.get("code") or "").strip().split("\n")[0][:60] ) status_text = ( f"Running Python: {preview}" if preview else "Running Python..." ) elif tool_name == "terminal": cmd_preview = (arguments.get("command") or "")[:60] status_text = ( f"Running: {cmd_preview}" if cmd_preview else "Running command..." 
) else: status_text = f"Calling: {tool_name}" yield {"type": "status", "text": status_text} # Emit tool_start so the frontend can record inputs yield { "type": "tool_start", "tool_name": tool_name, "tool_call_id": tc.get("id", ""), "arguments": arguments, } # Execute the tool _effective_timeout = ( None if tool_call_timeout >= 9999 else tool_call_timeout ) result = execute_tool( tool_name, arguments, cancel_event = cancel_event, timeout = _effective_timeout, session_id = session_id, ) # Emit tool_end so the frontend can record outputs yield { "type": "tool_end", "tool_name": tool_name, "tool_call_id": tc.get("id", ""), "result": result, } # Append tool result to conversation tool_msg = { "role": "tool", "name": tool_name, "content": result, } tool_call_id = tc.get("id") if tool_call_id: tool_msg["tool_call_id"] = tool_call_id conversation.append(tool_msg) # Continue the loop to let model respond with context continue # No tool calls -- model answered directly. # If no tools were executed at all, just yield the content # from this response instead of making a redundant second request. if iteration == 0 and content_text: yield {"type": "status", "text": ""} yield {"type": "content", "text": content_text} return # Tools were called in previous iterations; do a final # streaming pass so the model can synthesize a response # incorporating the tool results. 
break # Clear status yield {"type": "status", "text": ""} # Final streaming pass with the full conversation context stream_payload = { "messages": conversation, "stream": True, "temperature": temperature, "top_p": top_p, "top_k": top_k if top_k >= 0 else 0, "min_p": min_p, "repeat_penalty": repetition_penalty, "presence_penalty": presence_penalty, } if self._supports_reasoning and enable_thinking is not None: stream_payload["chat_template_kwargs"] = { "enable_thinking": enable_thinking } if max_tokens is not None: stream_payload["max_tokens"] = max_tokens if stop: stream_payload["stop"] = stop import re as _re_final # Closed blocks only -- safe to strip mid-stream without shrinking later. _TOOL_CLOSED_PATTERNS = [ _re_final.compile(r".*?", _re_final.DOTALL), _re_final.compile(r".*?", _re_final.DOTALL), ] # Open-ended patterns strip from an opening tag to end-of-string. # Only applied on the final flush to avoid non-monotonic shrinking. _TOOL_ALL_PATTERNS = _TOOL_CLOSED_PATTERNS + [ _re_final.compile(r".*$", _re_final.DOTALL), _re_final.compile(r".*$", _re_final.DOTALL), ] def _strip_tool_markup(text: str, *, final: bool = False) -> str: if not auto_heal_tool_calls: return text patterns = _TOOL_ALL_PATTERNS if final else _TOOL_CLOSED_PATTERNS for pat in patterns: text = pat.sub("", text) return text.strip() if final else text cumulative = "" _last_emitted = "" in_thinking = False has_content_tokens = False reasoning_text = "" try: stream_timeout = httpx.Timeout(connect = 10, read = 0.5, write = 10, pool = 10) with httpx.Client(timeout = stream_timeout) as client: with self._stream_with_retry( client, url, stream_payload, cancel_event ) as response: if response.status_code != 200: error_body = response.read().decode() raise RuntimeError( f"llama-server returned {response.status_code}: {error_body}" ) buffer = "" for raw_chunk in self._iter_text_cancellable( response, cancel_event ): buffer += raw_chunk while "\n" in buffer: line, buffer = buffer.split("\n", 1) line = 
line.strip() if not line: continue if line == "data: [DONE]": if in_thinking: if has_content_tokens: cumulative += "" yield { "type": "content", "text": _strip_tool_markup( cumulative, final = True ), } else: cumulative = reasoning_text yield {"type": "content", "text": cumulative} return if not line.startswith("data: "): continue try: chunk_data = json.loads(line[6:]) choices = chunk_data.get("choices", []) if choices: delta = choices[0].get("delta", {}) reasoning = delta.get("reasoning_content", "") if reasoning: reasoning_text += reasoning if not in_thinking: cumulative += "" in_thinking = True cumulative += reasoning yield {"type": "content", "text": cumulative} token = delta.get("content", "") if token: has_content_tokens = True if in_thinking: cumulative += "" in_thinking = False cumulative += token cleaned = _strip_tool_markup(cumulative) # Only emit when cleaned text grows (monotonic). if len(cleaned) > len(_last_emitted): _last_emitted = cleaned yield {"type": "content", "text": cleaned} except json.JSONDecodeError: logger.debug( f"Skipping malformed SSE line: {line[:100]}" ) except httpx.ConnectError: raise RuntimeError("Lost connection to llama-server") except Exception as e: if cancel_event is not None and cancel_event.is_set(): return raise # ── TTS support ──────────────────────────────────────────── def detect_audio_type(self) -> Optional[str]: """Detect audio/TTS codec by probing the loaded model's vocabulary. Codec-specific marker strings are sent to llama-server's /tokenize endpoint; a codec is reported when its markers each tokenize to exactly one token. Returns the codec name (e.g. 'csm', 'whisper', 'bicodec', 'dac') or None if no model is loaded, no marker set matches, or probing raises.""" if not self.is_loaded: return None try: with httpx.Client(timeout = 10) as client: def _detok(tid: int) -> str: r = client.post( f"{self.base_url}/detokenize", json = {"tokens": [tid]} ) return r.json().get("content", "") if r.status_code == 200 else "" def _tok(text: str) -> list[int]: r = client.post( f"{self.base_url}/tokenize", json = {"content": text, "add_special": False}, ) return r.json().get("tokens", []) if r.status_code == 200 else [] # Check codec-specific tokens (not generic ones that may exist in non-audio models) if "")) == 1 
and len(_tok("<|audio_eos|>")) == 1: return "csm" if len(_tok("<|startoftranscript|>")) == 1: return "whisper" if ( len(_tok("<|bicodec_semantic_0|>")) == 1 and len(_tok("<|bicodec_global_0|>")) == 1 ): return "bicodec" if len(_tok("<|c1_0|>")) == 1 and len(_tok("<|c2_0|>")) == 1: return "dac" except Exception as e: logger.debug(f"Audio type detection failed: {e}") return None # Prompt format per codec: (template, stop_tokens, needs_token_ids) # Matches prompts in InferenceBackend._generate_snac/bicodec/dac _TTS_PROMPTS = { "snac": ( "{text}<|eot_id|>", [""], True, ), "bicodec": ( "<|task_tts|><|start_content|>{text}<|end_content|><|start_global_token|>", ["<|im_end|>", ""], False, ), "dac": ( "<|im_start|>\n<|text_start|>{text}<|text_end|>\n<|audio_start|><|global_features_start|>\n", ["<|im_end|>", "<|audio_end|>"], False, ), } _codec_mgr = None # Shared AudioCodecManager instance def init_audio_codec(self, audio_type: str) -> None: """Load the audio codec at model load time (mirrors non-GGUF path). The AudioCodecManager is created once and shared at class level; the codec is loaded on CUDA when available, otherwise CPU. For 'bicodec', the unsloth/Spark-TTS-0.5B repo is first downloaded into ./Spark-TTS-0.5B (relative to the current working directory) to obtain its BiCodec weights.""" import torch from core.inference.audio_codecs import AudioCodecManager if LlamaCppBackend._codec_mgr is None: LlamaCppBackend._codec_mgr = AudioCodecManager() device = "cuda" if torch.cuda.is_available() else "cpu" model_repo_path = None # BiCodec needs a repo with BiCodec/ weights — download canonical SparkTTS if audio_type == "bicodec": from huggingface_hub import snapshot_download import os repo_path = snapshot_download( "unsloth/Spark-TTS-0.5B", local_dir = "Spark-TTS-0.5B" ) model_repo_path = os.path.abspath(repo_path) LlamaCppBackend._codec_mgr.load_codec( audio_type, device, model_repo_path = model_repo_path ) logger.info(f"Loaded audio codec for GGUF TTS: {audio_type}") def generate_audio_response( self, text: str, audio_type: str, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 50, min_p: float = 0.0, max_new_tokens: int = 2048, repetition_penalty: float = 1.1, ) -> tuple: """ Generate TTS audio via llama-server /completion + codec decoding.
Returns (wav_bytes, sample_rate). Raises RuntimeError if audio_type has no _TTS_PROMPTS entry or llama-server returns a non-200 status. NOTE(review): decoding uses the class-level LlamaCppBackend._codec_mgr, which is only created by init_audio_codec(); calling this before that would fail on None. """ if audio_type not in self._TTS_PROMPTS: raise RuntimeError(f"GGUF TTS does not support '{audio_type}' codec.") tpl, stop, need_ids = self._TTS_PROMPTS[audio_type] payload: dict = { "prompt": tpl.format(text = text), "stream": False, "n_predict": max_new_tokens, "temperature": temperature, "top_p": top_p, "top_k": top_k if top_k >= 0 else 0, "min_p": min_p, "repeat_penalty": repetition_penalty, } if stop: payload["stop"] = stop if need_ids: payload["n_probs"] = 1 with httpx.Client(timeout = httpx.Timeout(300, connect = 10)) as client: resp = client.post(f"{self.base_url}/completion", json = payload) if resp.status_code != 200: raise RuntimeError( f"llama-server returned {resp.status_code}: {resp.text}" ) data = resp.json() token_ids = ( [p["id"] for p in data.get("completion_probabilities", []) if "id" in p] if need_ids else None ) import torch device = "cuda" if torch.cuda.is_available() else "cpu" return LlamaCppBackend._codec_mgr.decode( audio_type, device, token_ids = token_ids, text = data.get("content", "") ) ================================================ FILE: studio/backend/core/inference/orchestrator.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference orchestrator — subprocess-based. Provides the same API as InferenceBackend, but delegates all ML work to a subprocess. Every model load shuts down any existing subprocess and spawns a fresh one, so each model starts from a clean interpreter; this is also how models that need different transformers versions (e.g. GLM-4.7-Flash needs 5.x, Qwen needs 4.57.x) get the correct version. 
The subprocess then stays alive to serve requests for the loaded model. Pattern follows core/training/training.py.
""" import atexit import base64 import structlog from loggers import get_logger import multiprocessing as mp import queue import threading import time import uuid from io import BytesIO from pathlib import Path from typing import Any, Generator, Optional, Tuple, Union logger = get_logger(__name__) _CTX = mp.get_context("spawn") # Dispatcher timeout constants (seconds) _DISPATCH_READ_TIMEOUT = 30.0 _DISPATCH_POLL_INTERVAL = 0.5 _DISPATCH_STOP_TIMEOUT = 5.0 _DISPATCH_IDLE_TIMEOUT = 30.0 _DISPATCH_DRAIN_TIMEOUT = 5.0 class InferenceOrchestrator: """ Inference backend orchestrator — subprocess-based. Exposes the same API surface as InferenceBackend so routes/inference.py needs minimal changes. Internally, all heavy ML operations happen in a persistent subprocess. """ def __init__(self): # Subprocess state self._proc: Optional[mp.Process] = None self._cmd_queue: Any = None self._resp_queue: Any = None self._cancel_event: Any = None # mp.Event — set to cancel generation instantly self._lock = threading.Lock() self._gen_lock = ( threading.Lock() ) # Serializes generation — one request at a time # Dispatcher state — for compare mode (adapter-controlled requests). # Instead of serializing via _gen_lock, adapter-controlled requests # send commands directly to the subprocess and read from per-request # mailboxes. A dispatcher thread routes resp_queue events by request_id. 
self._mailboxes: dict[str, queue.Queue] = {} self._mailbox_lock = threading.Lock() # Protects _mailboxes dict self._dispatcher_thread: Optional[threading.Thread] = None self._dispatcher_stop = threading.Event() # Local state mirrors (updated from subprocess responses) self.active_model_name: Optional[str] = None self.models: dict = {} self.loading_models: set = set() self.loaded_local_models: list = [] from core.inference.defaults import get_default_models self._static_models = get_default_models() self._top_gguf_cache: Optional[list[str]] = None self._top_hub_cache: Optional[list[str]] = None self._top_models_ready = threading.Event() # Version tracking for subprocess reuse self._current_transformers_major: Optional[str] = None # "4" or "5" atexit.register(self._cleanup) logger.info("InferenceOrchestrator initialized (subprocess mode)") # Kick off background fetch of top models from HF threading.Thread( target = self._fetch_top_models, daemon = True, name = "top-models" ).start() # ------------------------------------------------------------------ # Default models (top GGUFs fetched dynamically from HF) # ------------------------------------------------------------------ @property def default_models(self) -> list[str]: # Wait up to 5s for background HF fetch to finish self._top_models_ready.wait(timeout = 5) top_gguf = self._top_gguf_cache or [] top_hub = self._top_hub_cache or [] # GGUFs first, then hub models, then static fallbacks. # Send extras so the frontend still has 4 per category # after removing already-downloaded models. 
result: list[str] = [] seen: set[str] = set() for m in top_gguf + top_hub + self._static_models: if m not in seen: result.append(m) seen.add(m) return result def _fetch_top_models(self) -> None: """Fetch top GGUF and non-GGUF repos from unsloth by downloads.""" try: import httpx resp = httpx.get( "https://huggingface.co/api/models", params = { "author": "unsloth", "sort": "downloads", "direction": "-1", "limit": "80", }, timeout = 15, ) if resp.status_code == 200: models = resp.json() # Top 40 GGUFs - frontend pages through them on-demand via # infinite scroll, so we send a deep pool. gguf_ids = [ m["id"] for m in models if m.get("id", "").upper().endswith("-GGUF") ][:40] # Top 40 non-GGUF hub models hub_ids = [ m["id"] for m in models if not m.get("id", "").upper().endswith("-GGUF") ][:40] if gguf_ids: self._top_gguf_cache = gguf_ids logger.info("Top GGUF models: %s", gguf_ids) if hub_ids: self._top_hub_cache = hub_ids logger.info("Top hub models: %s", hub_ids) except Exception as e: logger.warning("Failed to fetch top models: %s", e) finally: self._top_models_ready.set() # ------------------------------------------------------------------ # Subprocess lifecycle # ------------------------------------------------------------------ def _spawn_subprocess(self, config: dict) -> None: """Spawn a new inference subprocess.""" from .worker import run_inference_process self._cmd_queue = _CTX.Queue() self._resp_queue = _CTX.Queue() self._cancel_event = _CTX.Event() self._proc = _CTX.Process( target = run_inference_process, kwargs = { "cmd_queue": self._cmd_queue, "resp_queue": self._resp_queue, "cancel_event": self._cancel_event, "config": config, }, daemon = True, ) self._proc.start() logger.info("Inference subprocess started (pid=%s)", self._proc.pid) def _cancel_generation(self) -> None: """Cancel any ongoing generation in the subprocess (instant).""" if self._cancel_event is not None: self._cancel_event.set() def _shutdown_subprocess(self, timeout: float = 10.0) -> 
None: """Gracefully shut down the inference subprocess.""" self._stop_dispatcher() # Stop dispatcher before killing subprocess if self._proc is None or not self._proc.is_alive(): self._proc = None return # 1. Cancel any ongoing generation first (instant via mp.Event) self._cancel_generation() time.sleep(0.5) # Brief wait for generation to stop # 2. Drain stale responses from queue self._drain_queue() # 3. Send shutdown command try: self._cmd_queue.put({"type": "shutdown"}) except (OSError, ValueError): pass # 4. Wait for graceful shutdown try: self._proc.join(timeout = timeout) except Exception: pass # 5. Force kill if still alive if self._proc is not None and self._proc.is_alive(): logger.warning("Inference subprocess did not exit gracefully, terminating") try: self._proc.terminate() self._proc.join(timeout = 5) except Exception: pass if self._proc is not None and self._proc.is_alive(): logger.warning("Subprocess still alive after terminate, killing") try: self._proc.kill() self._proc.join(timeout = 3) except Exception: pass self._proc = None self._cmd_queue = None self._resp_queue = None self._cancel_event = None logger.info("Inference subprocess shut down") def _cleanup(self): """atexit handler.""" self._shutdown_subprocess(timeout = 5.0) def _ensure_subprocess_alive(self) -> bool: """Check if subprocess is alive.""" return self._proc is not None and self._proc.is_alive() # ------------------------------------------------------------------ # Queue helpers # ------------------------------------------------------------------ def _send_cmd(self, cmd: dict) -> None: """Send a command to the subprocess.""" if self._cmd_queue is None: raise RuntimeError("No inference subprocess running") try: self._cmd_queue.put(cmd) except (OSError, ValueError) as exc: raise RuntimeError(f"Failed to send command to subprocess: {exc}") def _read_resp(self, timeout: float = 1.0) -> Optional[dict]: """Read a response from the subprocess (non-blocking with timeout).""" if 
self._resp_queue is None: return None try: return self._resp_queue.get(timeout = timeout) except queue.Empty: return None except (EOFError, OSError, ValueError): return None def _wait_response(self, expected_type: str, timeout: float = 120.0) -> dict: """Block until a response of the expected type arrives. Also handles 'status' and 'error' events during the wait. Returns the matching response dict. Raises RuntimeError on timeout or subprocess crash. """ deadline = time.monotonic() + timeout while time.monotonic() < deadline: remaining = max(0.1, deadline - time.monotonic()) resp = self._read_resp(timeout = min(remaining, 1.0)) if resp is None: # Check subprocess health if not self._ensure_subprocess_alive(): raise RuntimeError("Inference subprocess crashed during wait") continue rtype = resp.get("type", "") if rtype == expected_type: return resp if rtype == "error": error_msg = resp.get("error", "Unknown error") raise RuntimeError(f"Subprocess error: {error_msg}") if rtype == "status": logger.info("Subprocess status: %s", resp.get("message", "")) continue # Other response types during wait — skip logger.debug( "Skipping response type '%s' while waiting for '%s'", rtype, expected_type, ) raise RuntimeError( f"Timeout waiting for '{expected_type}' response after {timeout}s" ) def _drain_queue(self) -> list: """Drain all pending responses.""" events = [] if self._resp_queue is None: return events while True: try: events.append(self._resp_queue.get_nowait()) except queue.Empty: return events except (EOFError, OSError, ValueError): return events def _drain_until_gen_done(self, timeout: float = 5.0) -> None: """Consume resp_queue events until gen_done/gen_error, discarding them. Called after cancel to ensure stale tokens from the cancelled generation don't leak into the next request. 
""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: resp = self._read_resp(timeout = min(0.5, deadline - time.monotonic())) if resp is None: if not self._ensure_subprocess_alive(): return continue rtype = resp.get("type", "") if rtype in ("gen_done", "gen_error"): return logger.warning("Timed out waiting for gen_done after cancel") # ------------------------------------------------------------------ # Dispatcher — per-request mailbox routing for compare mode # ------------------------------------------------------------------ def _start_dispatcher(self) -> None: """Start the dispatcher thread if not already running. The dispatcher reads from the shared resp_queue and routes responses to per-request mailbox queues. This allows multiple adapter-controlled (compare) requests to be in-flight without holding _gen_lock. """ if self._dispatcher_thread is not None and self._dispatcher_thread.is_alive(): return self._dispatcher_stop.clear() self._dispatcher_thread = threading.Thread( target = self._dispatcher_loop, daemon = True, name = "inference-dispatcher", ) self._dispatcher_thread.start() logger.debug("Dispatcher thread started") def _stop_dispatcher(self) -> None: """Signal the dispatcher to stop and wait for it.""" if self._dispatcher_thread is None: return self._dispatcher_stop.set() self._dispatcher_thread.join(timeout = _DISPATCH_STOP_TIMEOUT) self._dispatcher_thread = None logger.debug("Dispatcher thread stopped") def _dispatcher_loop(self) -> None: """Background loop: read resp_queue → route to mailboxes by request_id.""" while not self._dispatcher_stop.is_set(): if self._resp_queue is None: break try: resp = self._resp_queue.get(timeout = _DISPATCH_POLL_INTERVAL) except queue.Empty: continue except (EOFError, OSError, ValueError): break rid = resp.get("request_id") rtype = resp.get("type", "") # Status messages — log and skip if rtype == "status": logger.info("Subprocess status: %s", resp.get("message", "")) continue # Route to mailbox 
if a matching request_id exists if rid: with self._mailbox_lock: mbox = self._mailboxes.get(rid) if mbox is not None: mbox.put(resp) continue # No matching mailbox — might be for a _gen_lock reader or orphaned # Push it back so _read_resp can pick it up. But we can't un-get # from mp.Queue, so log a warning. if rtype not in ("status",): logger.debug( "Dispatcher: no mailbox for request_id=%s type=%s, dropping", rid, rtype, ) def _generate_dispatched( self, messages: list = None, system_prompt: str = "", image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, use_adapter = None, ) -> Generator[str, None, None]: """Dispatched generation — sends command without holding _gen_lock. Uses a per-request mailbox to receive tokens. This allows two compare-mode requests to be queued in the subprocess simultaneously, eliminating the inter-generation round-trip overhead. The subprocess processes commands sequentially from its cmd_queue, so generation is still serialized at the GPU level — we just avoid the orchestrator-level lock contention. 
""" if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess is not running" return if not self.active_model_name: yield "Error: No active model" return # Ensure dispatcher is running self._start_dispatcher() request_id = str(uuid.uuid4()) # Convert PIL Image to base64 if needed image_b64 = None if image is not None: image_b64 = self._pil_to_base64(image) cmd = { "type": "generate", "request_id": request_id, "messages": messages or [], "system_prompt": system_prompt, "image_base64": image_b64, "temperature": temperature, "top_p": top_p, "top_k": top_k, "min_p": min_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, } if use_adapter is not None: cmd["use_adapter"] = use_adapter # Create mailbox BEFORE sending command mailbox: queue.Queue = queue.Queue() with self._mailbox_lock: self._mailboxes[request_id] = mailbox try: self._send_cmd(cmd) except RuntimeError as exc: with self._mailbox_lock: self._mailboxes.pop(request_id, None) yield f"Error: {exc}" return # Read tokens from our private mailbox try: while True: try: resp = mailbox.get(timeout = _DISPATCH_READ_TIMEOUT) except queue.Empty: # Timeout — check subprocess health if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess crashed during generation" return continue rtype = resp.get("type", "") if rtype == "token": # Check cancel from route (e.g. 
SSE connection closed) if cancel_event is not None and cancel_event.is_set(): self._cancel_generation() # Drain remaining events for this request self._drain_mailbox(mailbox, timeout = 5.0) return yield resp.get("text", "") elif rtype == "gen_done": return elif rtype == "gen_error": yield f"Error: {resp.get('error', 'Unknown error')}" return finally: with self._mailbox_lock: self._mailboxes.pop(request_id, None) def _drain_mailbox(self, mailbox: queue.Queue, timeout: float = 5.0) -> None: """Drain a mailbox until gen_done/gen_error, discarding tokens.""" deadline = time.monotonic() + timeout while time.monotonic() < deadline: try: resp = mailbox.get( timeout = min(_DISPATCH_POLL_INTERVAL, deadline - time.monotonic()) ) except queue.Empty: continue rtype = resp.get("type", "") if rtype in ("gen_done", "gen_error"): return logger.warning("Timed out draining mailbox after cancel") def _wait_dispatcher_idle(self) -> None: """Wait for all dispatched requests to complete, then stop dispatcher. Called by _generate_inner before using the _gen_lock path, to ensure the dispatcher thread isn't competing for resp_queue reads. """ if self._dispatcher_thread is None or not self._dispatcher_thread.is_alive(): return # Wait for all mailboxes to be emptied (dispatched requests complete) deadline = time.monotonic() + _DISPATCH_IDLE_TIMEOUT while time.monotonic() < deadline: with self._mailbox_lock: if not self._mailboxes: break time.sleep(0.1) # Only stop dispatcher if all mailboxes drained. If compare # requests are still active, leave the dispatcher running so # their token routing isn't killed mid-stream. 
with self._mailbox_lock: still_active = bool(self._mailboxes) if still_active: logger.warning( "Dispatcher still has %d active mailbox(es); " "leaving dispatcher running for compare requests", len(self._mailboxes), ) else: self._stop_dispatcher() # ------------------------------------------------------------------ # Public API — same interface as InferenceBackend # ------------------------------------------------------------------ def load_model( self, config, # ModelConfig max_seq_length: int = 2048, dtype = None, load_in_4bit: bool = True, hf_token: Optional[str] = None, trust_remote_code: bool = False, ) -> bool: """Load a model for inference. Always spawns a fresh subprocess for each model load. This ensures a clean Python interpreter — no stale unsloth patches, torch.compile caches, or inspect.getsource() failures from a previous model. """ from utils.transformers_version import needs_transformers_5 model_name = config.identifier self.loading_models.add(model_name) try: needed_major = "5" if needs_transformers_5(model_name) else "4" # Build config dict for subprocess sub_config = { "model_name": model_name, "max_seq_length": max_seq_length, "load_in_4bit": load_in_4bit, "hf_token": hf_token or "", "gguf_variant": getattr(config, "gguf_variant", None), "trust_remote_code": trust_remote_code, } # Always kill existing subprocess and spawn fresh. # Reusing a subprocess after unsloth patches torch internals # causes inspect.getsource() failures on the next model load. 
if self._ensure_subprocess_alive(): self._cancel_generation() time.sleep(0.3) self._shutdown_subprocess() elif self._proc is not None: # Dead subprocess — clean up self._shutdown_subprocess(timeout = 2) logger.info( "Spawning fresh inference subprocess for '%s' (transformers %s.x)", model_name, needed_major, ) self._spawn_subprocess(sub_config) resp = self._wait_response("loaded", timeout = 180) # Update local state from response if resp.get("success"): self._current_transformers_major = needed_major model_info = resp.get("model_info", {}) self.active_model_name = model_info.get("identifier", model_name) self.models[self.active_model_name] = { "is_vision": model_info.get("is_vision", False), "is_lora": model_info.get("is_lora", False), "display_name": model_info.get("display_name", model_name), "is_audio": model_info.get("is_audio", False), "audio_type": model_info.get("audio_type"), "has_audio_input": model_info.get("has_audio_input", False), } self.loading_models.discard(model_name) logger.info("Model '%s' loaded successfully in subprocess", model_name) return True else: error = resp.get("error", "Failed to load model") self.loading_models.discard(model_name) self.active_model_name = None self.models.clear() raise Exception(error) except Exception: self.loading_models.discard(model_name) self.active_model_name = None self.models.clear() raise def unload_model(self, model_name: str) -> bool: """Unload a model from the subprocess.""" if not self._ensure_subprocess_alive(): # No subprocess — just clear local state self.models.pop(model_name, None) if self.active_model_name == model_name: self.active_model_name = None return True try: self._send_cmd( { "type": "unload", "model_name": model_name, } ) resp = self._wait_response("unloaded", timeout = 30) # Update local state self.models.pop(model_name, None) if self.active_model_name == model_name: self.active_model_name = None logger.info("Model '%s' unloaded from subprocess", model_name) return True except Exception 
as exc: logger.error("Error unloading model '%s': %s", model_name, exc) # Clear local state anyway self.models.pop(model_name, None) if self.active_model_name == model_name: self.active_model_name = None return False def generate_chat_response( self, messages: list, system_prompt: str = "", image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, ) -> Generator[str, None, None]: """Generate response, streaming tokens from subprocess.""" yield from self._generate_inner( messages = messages, system_prompt = system_prompt, image = image, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, max_new_tokens = max_new_tokens, repetition_penalty = repetition_penalty, cancel_event = cancel_event, use_adapter = None, ) def generate_with_adapter_control( self, use_adapter: Optional[Union[bool, str]] = None, cancel_event = None, **gen_kwargs, ) -> Generator[str, None, None]: """Generate with adapter control, streaming tokens from subprocess. Uses the dispatcher path (no _gen_lock) so that compare-mode requests don't block each other. The subprocess naturally serializes them via its sequential command loop. """ yield from self._generate_dispatched( use_adapter = use_adapter, cancel_event = cancel_event, **gen_kwargs, ) def _generate_inner( self, messages: list = None, system_prompt: str = "", image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, use_adapter = None, ) -> Generator[str, None, None]: """Inner generation logic — sends command to subprocess, yields tokens. Serialized by _gen_lock: only one generation runs at a time. This prevents concurrent readers from consuming each other's tokens off the shared resp_queue. 
""" if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess is not running" return if not self.active_model_name: yield "Error: No active model" return # If the dispatcher is running (from a previous compare-mode request), # wait for all dispatched requests to finish, then stop the dispatcher # so we can safely read from resp_queue directly. self._wait_dispatcher_idle() # Serialize generation — single GPU, one generation at a time. # Without this lock, two concurrent readers on the same resp_queue # can consume and drop each other's token events. with self._gen_lock: yield from self._generate_locked( messages = messages, system_prompt = system_prompt, image = image, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, max_new_tokens = max_new_tokens, repetition_penalty = repetition_penalty, cancel_event = cancel_event, use_adapter = use_adapter, ) def _generate_locked( self, messages: list = None, system_prompt: str = "", image = None, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 256, repetition_penalty: float = 1.0, cancel_event = None, use_adapter = None, ) -> Generator[str, None, None]: """Actual generation logic — must be called under _gen_lock.""" request_id = str(uuid.uuid4()) # Convert PIL Image to base64 if needed image_b64 = None if image is not None: image_b64 = self._pil_to_base64(image) cmd = { "type": "generate", "request_id": request_id, "messages": messages or [], "system_prompt": system_prompt, "image_base64": image_b64, "temperature": temperature, "top_p": top_p, "top_k": top_k, "min_p": min_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, } if use_adapter is not None: cmd["use_adapter"] = use_adapter try: self._send_cmd(cmd) except RuntimeError as exc: yield f"Error: {exc}" return # Yield tokens from response queue — we are the only reader # because _gen_lock is held. 
while True: resp = self._read_resp(timeout = 30.0) if resp is None: # Check subprocess health if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess crashed during generation" return continue rtype = resp.get("type", "") # Status messages — skip if rtype == "status": continue # Error without request_id = subprocess-level error resp_rid = resp.get("request_id") if rtype == "error" and not resp_rid: yield f"Error: {resp.get('error', 'Unknown error')}" return if rtype == "token": # Check cancel from route (e.g. SSE connection closed) if cancel_event is not None and cancel_event.is_set(): self._cancel_generation() # Wait for the subprocess to acknowledge cancellation # (gen_done/gen_error) so stale events don't leak into # the next generation request. self._drain_until_gen_done(timeout = 5.0) return yield resp.get("text", "") elif rtype == "gen_done": return elif rtype == "gen_error": yield f"Error: {resp.get('error', 'Unknown error')}" return def reset_generation_state(self): """Cancel any ongoing generation and reset state.""" self._cancel_generation() if not self._ensure_subprocess_alive(): return try: self._send_cmd({"type": "reset"}) except RuntimeError: pass # ------------------------------------------------------------------ # Audio generation — TTS, ASR, audio input # ------------------------------------------------------------------ def generate_audio_response( self, text: str, temperature: float = 0.6, top_p: float = 0.95, top_k: int = 50, min_p: float = 0.0, max_new_tokens: int = 2048, repetition_penalty: float = 1.0, use_adapter: Optional[Union[bool, str]] = None, ) -> Tuple[bytes, int]: """Generate TTS audio. Returns (wav_bytes, sample_rate). Blocking — sends command and waits for the complete audio response. 
""" if not self._ensure_subprocess_alive(): raise RuntimeError("Inference subprocess is not running") if not self.active_model_name: raise RuntimeError("No active model") import uuid request_id = str(uuid.uuid4()) cmd = { "type": "generate_audio", "request_id": request_id, "text": text, "temperature": temperature, "top_p": top_p, "top_k": top_k, "min_p": min_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, } if use_adapter is not None: cmd["use_adapter"] = use_adapter self._send_cmd(cmd) # Wait for audio_done or audio_error deadline = time.monotonic() + 120.0 while time.monotonic() < deadline: remaining = max(0.1, deadline - time.monotonic()) resp = self._read_resp(timeout = min(remaining, 1.0)) if resp is None: if not self._ensure_subprocess_alive(): raise RuntimeError( "Inference subprocess crashed during audio generation" ) continue rtype = resp.get("type", "") if rtype == "audio_done": wav_bytes = base64.b64decode(resp["wav_base64"]) sample_rate = resp["sample_rate"] return wav_bytes, sample_rate if rtype == "audio_error": raise RuntimeError(resp.get("error", "Audio generation failed")) if rtype == "error": raise RuntimeError(resp.get("error", "Unknown error")) if rtype == "status": continue raise RuntimeError("Timeout waiting for audio generation (120s)") def generate_whisper_response( self, audio_array, cancel_event = None, ) -> Generator[str, None, None]: """Whisper ASR — sends audio to subprocess, yields text.""" yield from self._generate_audio_input_inner( audio_array = audio_array, audio_type = "whisper", messages = [], system_prompt = "", cancel_event = cancel_event, ) def generate_audio_input_response( self, messages, system_prompt, audio_array, temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 512, repetition_penalty: float = 1.0, cancel_event = None, ) -> Generator[str, None, None]: """Audio input generation (e.g. 
Gemma 3n) — streams text tokens.""" yield from self._generate_audio_input_inner( audio_array = audio_array, audio_type = None, # worker will use generate_audio_input_response messages = messages, system_prompt = system_prompt, temperature = temperature, top_p = top_p, top_k = top_k, min_p = min_p, max_new_tokens = max_new_tokens, repetition_penalty = repetition_penalty, cancel_event = cancel_event, ) def _generate_audio_input_inner( self, audio_array, audio_type: Optional[str] = None, messages: list = None, system_prompt: str = "", temperature: float = 0.7, top_p: float = 0.9, top_k: int = 40, min_p: float = 0.0, max_new_tokens: int = 512, repetition_penalty: float = 1.0, cancel_event = None, ) -> Generator[str, None, None]: """Shared inner logic for audio input generation (Whisper + ASR).""" if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess is not running" return if not self.active_model_name: yield "Error: No active model" return with self._gen_lock: import uuid request_id = str(uuid.uuid4()) # Convert numpy array to list for mp.Queue serialization audio_data = ( audio_array.tolist() if hasattr(audio_array, "tolist") else list(audio_array) ) cmd = { "type": "generate_audio_input", "request_id": request_id, "audio_data": audio_data, "audio_type": audio_type, "messages": messages or [], "system_prompt": system_prompt, "temperature": temperature, "top_p": top_p, "top_k": top_k, "min_p": min_p, "max_new_tokens": max_new_tokens, "repetition_penalty": repetition_penalty, } try: self._send_cmd(cmd) except RuntimeError as exc: yield f"Error: {exc}" return # Yield tokens — same pattern as _generate_locked while True: resp = self._read_resp(timeout = 30.0) if resp is None: if not self._ensure_subprocess_alive(): yield "Error: Inference subprocess crashed during audio input generation" return continue rtype = resp.get("type", "") if rtype == "status": continue if rtype == "error" and not resp.get("request_id"): yield f"Error: {resp.get('error', 
'Unknown error')}" return if rtype == "token": if cancel_event is not None and cancel_event.is_set(): self._cancel_generation() self._drain_until_gen_done(timeout = 5.0) return yield resp.get("text", "") elif rtype == "gen_done": return elif rtype == "gen_error": yield f"Error: {resp.get('error', 'Unknown error')}" return # ------------------------------------------------------------------ # Local helpers (no subprocess needed) # ------------------------------------------------------------------ def resize_image(self, img, max_size: int = 800): """Resize image while maintaining aspect ratio. No ML imports needed — runs locally in parent process. """ if img is None: return None if img.size[0] > max_size or img.size[1] > max_size: from PIL import Image ratio = min(max_size / img.size[0], max_size / img.size[1]) new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio)) return img.resize(new_size, Image.Resampling.LANCZOS) return img @staticmethod def _pil_to_base64(img) -> str: """Convert a PIL Image to base64 string for IPC.""" buf = BytesIO() img.save(buf, format = "PNG") return base64.b64encode(buf.getvalue()).decode("ascii") def get_current_model(self) -> Optional[str]: """Get currently active model name.""" return self.active_model_name def is_model_loading(self) -> bool: """Check if any model is currently loading.""" return len(self.loading_models) > 0 def get_loading_model(self) -> Optional[str]: """Get name of currently loading model.""" return next(iter(self.loading_models)) if self.loading_models else None def check_vision_model_compatibility(self) -> bool: """Check if current model supports vision.""" if self.active_model_name and self.active_model_name in self.models: return self.models[self.active_model_name].get("is_vision", False) return False # ========== GLOBAL INSTANCE ========== _inference_backend = None def get_inference_backend() -> InferenceOrchestrator: """Get global inference backend instance (orchestrator).""" global _inference_backend 
if _inference_backend is None: _inference_backend = InferenceOrchestrator() return _inference_backend ================================================ FILE: studio/backend/core/inference/tools.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Tool definitions and executors for LLM tool calling. Supports web search (DuckDuckGo), Python code execution, and terminal commands. """ import ast import os os.environ["UNSLOTH_IS_PRESENT"] = "1" import subprocess import sys import tempfile import threading from loggers import get_logger logger = get_logger(__name__) _EXEC_TIMEOUT = 300 # 5 minutes _MAX_OUTPUT_CHARS = 8000 # truncate long output _BASH_BLOCKED_WORDS = {"rm", "sudo", "dd", "chmod", "mkfs", "shutdown", "reboot"} # Per-session working directories so each chat thread gets its own sandbox. # Falls back to a shared ~/studio_sandbox/ for API callers without a session_id. 
_workdirs: dict[str, str] = {} def _get_workdir(session_id: str | None = None) -> str: """Return (and lazily create) a persistent working directory for tool execution.""" global _workdirs key = session_id or "_default" if key not in _workdirs or not os.path.isdir(_workdirs[key]): home = os.path.expanduser("~") sandbox_root = os.path.join(home, "studio_sandbox") if session_id: # Sanitize: strip path separators and parent-dir references safe_id = os.path.basename(session_id.replace("..", "")) if not safe_id: safe_id = "_invalid" workdir = os.path.join(sandbox_root, safe_id) # Verify resolved path stays under sandbox root if not os.path.realpath(workdir).startswith(os.path.realpath(sandbox_root)): workdir = os.path.join(sandbox_root, "_invalid") else: workdir = sandbox_root os.makedirs(workdir, exist_ok = True) _workdirs[key] = workdir return _workdirs[key] WEB_SEARCH_TOOL = { "type": "function", "function": { "name": "web_search", "description": "Search the web for current information, recent events, or facts you are uncertain about.", "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "The search query", } }, "required": ["query"], }, }, } PYTHON_TOOL = { "type": "function", "function": { "name": "python", "description": "Execute Python code in a sandbox and return stdout/stderr.", "parameters": { "type": "object", "properties": { "code": { "type": "string", "description": "The Python code to run", } }, "required": ["code"], }, }, } TERMINAL_TOOL = { "type": "function", "function": { "name": "terminal", "description": "Execute a terminal command and return stdout/stderr.", "parameters": { "type": "object", "properties": { "command": { "type": "string", "description": "The command to run", } }, "required": ["command"], }, }, } ALL_TOOLS = [WEB_SEARCH_TOOL, PYTHON_TOOL, TERMINAL_TOOL] _TIMEOUT_UNSET = object() def execute_tool( name: str, arguments: dict, cancel_event = None, timeout: int | None = _TIMEOUT_UNSET, session_id: 
str | None = None, ) -> str: """Execute a tool by name with the given arguments. Returns result as a string. ``timeout``: int sets per-call limit in seconds, ``None`` means no limit, unset (default) uses ``_EXEC_TIMEOUT`` (300 s). ``session_id``: optional thread/session ID for per-conversation sandbox isolation. """ logger.info( f"execute_tool: name={name}, session_id={session_id}, timeout={timeout}" ) effective_timeout = _EXEC_TIMEOUT if timeout is _TIMEOUT_UNSET else timeout if name == "web_search": return _web_search(arguments.get("query", ""), timeout = effective_timeout) if name == "python": return _python_exec( arguments.get("code", ""), cancel_event, effective_timeout, session_id ) if name == "terminal": return _bash_exec( arguments.get("command", ""), cancel_event, effective_timeout, session_id ) return f"Unknown tool: {name}" def _web_search(query: str, max_results: int = 5, timeout: int = _EXEC_TIMEOUT) -> str: """Search the web using DuckDuckGo and return formatted results.""" if not query.strip(): return "No query provided." try: from ddgs import DDGS results = DDGS(timeout = timeout).text(query, max_results = max_results) if not results: return "No results found." parts = [] for r in results: parts.append( f"Title: {r.get('title', '')}\n" f"URL: {r.get('href', '')}\n" f"Snippet: {r.get('body', '')}" ) return "\n\n---\n\n".join(parts) except Exception as e: return f"Search failed: {e}" def _check_signal_escape_patterns(code: str): """ Check if code contains patterns that could escape signal-based timeouts. Vendored from unsloth_zoo.rl_environments to avoid importing unsloth_zoo (which requires GPU drivers and fails on Mac/Apple Silicon). 
Returns (safe: bool, details: dict) """ try: tree = ast.parse(code) except SyntaxError as e: return False, { "error": f"SyntaxError: {e}", "signal_tampering": [], "exception_catching": [], "warnings": [], } signal_tampering = [] exception_catching = [] warnings = [] def _ast_name_matches(node, names): if isinstance(node, ast.Name): return node.id in names elif isinstance(node, ast.Attribute): full_name = [] current = node while isinstance(current, ast.Attribute): full_name.append(current.attr) current = current.value if isinstance(current, ast.Name): full_name.append(current.id) full_name = ".".join(reversed(full_name)) return full_name in names return False class SignalEscapeVisitor(ast.NodeVisitor): def __init__(self): self.imports_signal = False self.signal_aliases = {"signal"} self.loop_depth = 0 def visit_Import(self, node): for alias in node.names: if alias.name == "signal": self.imports_signal = True if alias.asname: self.signal_aliases.add(alias.asname) self.generic_visit(node) def visit_ImportFrom(self, node): if node.module == "signal": self.imports_signal = True for alias in node.names: if alias.name in ( "signal", "SIGALRM", "SIG_IGN", "setitimer", "ITIMER_REAL", "pthread_sigmask", "SIG_BLOCK", "alarm", ): self.signal_aliases.add(alias.asname or alias.name) self.generic_visit(node) def visit_While(self, node): self.loop_depth += 1 self.generic_visit(node) self.loop_depth -= 1 def visit_For(self, node): self.loop_depth += 1 self.generic_visit(node) self.loop_depth -= 1 def visit_Call(self, node): func = node.func func_name = None if isinstance(func, ast.Attribute): if isinstance(func.value, ast.Name): if func.value.id in self.signal_aliases: func_name = f"signal.{func.attr}" elif isinstance(func, ast.Name): if func.id in ("signal", "setitimer", "alarm", "pthread_sigmask"): func_name = func.id if func_name: if func_name in ("signal.signal", "signal"): if len(node.args) >= 1: if _ast_name_matches( node.args[0], ("SIGALRM", "signal.SIGALRM") ): 
signal_tampering.append( { "type": "signal_handler_override", "line": node.lineno, "description": "Overrides SIGALRM handler", } ) elif func_name in ("signal.setitimer", "setitimer"): if len(node.args) >= 1: if _ast_name_matches( node.args[0], ("ITIMER_REAL", "signal.ITIMER_REAL") ): signal_tampering.append( { "type": "timer_manipulation", "line": node.lineno, "description": "Manipulates ITIMER_REAL timer", } ) elif func_name in ("signal.alarm", "alarm"): signal_tampering.append( { "type": "alarm_manipulation", "line": node.lineno, "description": "Manipulates alarm timer", } ) elif func_name in ("signal.pthread_sigmask", "pthread_sigmask"): signal_tampering.append( { "type": "signal_mask", "line": node.lineno, "description": "Modifies signal mask (may block SIGALRM)", } ) self.generic_visit(node) def visit_ExceptHandler(self, node): if self.loop_depth == 0: self.generic_visit(node) return if node.type is None: exception_catching.append( { "type": "bare_except_in_loop", "line": node.lineno, "description": "Bare except in loop catches TimeoutError and continues looping", } ) elif isinstance(node.type, ast.Name): if node.type.id in ("TimeoutError", "BaseException", "Exception"): exception_catching.append( { "type": f"catches_{node.type.id}_in_loop", "line": node.lineno, "description": f"Catches {node.type.id} in loop - may suppress timeout and continue", } ) elif isinstance(node.type, ast.Tuple): for elt in node.type.elts: if isinstance(elt, ast.Name): if elt.id in ("TimeoutError", "BaseException", "Exception"): exception_catching.append( { "type": f"catches_{elt.id}_in_loop", "line": node.lineno, "description": f"Catches {elt.id} in loop - may suppress timeout and continue", } ) self.generic_visit(node) visitor = SignalEscapeVisitor() visitor.visit(tree) if visitor.imports_signal and not signal_tampering: warnings.append("Code imports 'signal' module - review manually for safety") is_safe = len(signal_tampering) == 0 and len(exception_catching) == 0 return is_safe, { 
"signal_tampering": signal_tampering, "exception_catching": exception_catching, "warnings": warnings, } def _check_code_safety(code: str) -> str | None: """Validate code safety via static analysis. Returns an error message string if the code is unsafe, or None if OK. """ safe, info = _check_signal_escape_patterns(code) if not safe: reasons = [ item.get("description", "") for item in info.get("signal_tampering", []) ] return ( f"Error: unsafe code detected ({'; '.join(reasons)}). " f"Please remove signal manipulation from your code." ) return None def _cancel_watcher(proc, cancel_event, poll_interval = 0.2): """Daemon thread that kills a process when cancel_event is set.""" while proc.poll() is None: if cancel_event is not None and cancel_event.is_set(): proc.kill() return cancel_event.wait(poll_interval) if cancel_event else None def _truncate(text: str, limit: int = _MAX_OUTPUT_CHARS) -> str: if len(text) > limit: return text[:limit] + f"\n\n... (truncated, {len(text)} chars total)" return text def _python_exec( code: str, cancel_event = None, timeout: int = _EXEC_TIMEOUT, session_id: str | None = None, ) -> str: """Execute Python code in a subprocess sandbox.""" if not code or not code.strip(): return "No code provided." 
# Validate imports and code safety error = _check_code_safety(code) if error: return error tmp_path = None workdir = _get_workdir(session_id) try: fd, tmp_path = tempfile.mkstemp( suffix = ".py", prefix = "studio_exec_", dir = workdir ) with os.fdopen(fd, "w") as f: f.write(code) proc = subprocess.Popen( [sys.executable, tmp_path], stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, cwd = workdir, ) # Spawn cancel watcher if we have a cancel event if cancel_event is not None: watcher = threading.Thread( target = _cancel_watcher, args = (proc, cancel_event), daemon = True ) watcher.start() try: output, _ = proc.communicate(timeout = timeout) except subprocess.TimeoutExpired: proc.kill() proc.communicate() return _truncate(f"Execution timed out after {timeout} seconds.") if cancel_event is not None and cancel_event.is_set(): return "Execution cancelled." result = output or "" if proc.returncode != 0: result = f"Exit code {proc.returncode}:\n{result}" return _truncate(result) if result.strip() else "(no output)" except Exception as e: return f"Execution error: {e}" finally: if tmp_path and os.path.exists(tmp_path): try: os.unlink(tmp_path) except OSError: pass def _bash_exec( command: str, cancel_event = None, timeout: int = _EXEC_TIMEOUT, session_id: str | None = None, ) -> str: """Execute a bash command in a subprocess sandbox.""" if not command or not command.strip(): return "No command provided." 
# Block dangerous commands tokens = set(command.lower().split()) blocked = tokens & _BASH_BLOCKED_WORDS if blocked: return f"Blocked command(s) for safety: {', '.join(sorted(blocked))}" try: workdir = _get_workdir(session_id) proc = subprocess.Popen( ["bash", "-c", command], stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, cwd = workdir, ) if cancel_event is not None: watcher = threading.Thread( target = _cancel_watcher, args = (proc, cancel_event), daemon = True ) watcher.start() try: output, _ = proc.communicate(timeout = timeout) except subprocess.TimeoutExpired: proc.kill() proc.communicate() return _truncate(f"Execution timed out after {timeout} seconds.") if cancel_event is not None and cancel_event.is_set(): return "Execution cancelled." result = output or "" if proc.returncode != 0: result = f"Exit code {proc.returncode}:\n{result}" return _truncate(result) if result.strip() else "(no output)" except Exception as e: return f"Execution error: {e}" ================================================ FILE: studio/backend/core/inference/worker.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference subprocess entry point. Each inference session runs in a persistent subprocess (mp.get_context("spawn")). This gives us a clean Python interpreter with no stale module state — solving the transformers version-switching problem completely. The subprocess stays alive while a model is loaded, accepting commands (generate, load, unload) via mp.Queue. It exits on shutdown or unload. Pattern follows core/training/worker.py. 
""" from __future__ import annotations import base64 import structlog from loggers import get_logger import os import queue as _queue import sys import time import traceback from io import BytesIO from pathlib import Path from typing import Any logger = get_logger(__name__) def _activate_transformers_version(model_name: str) -> None: """Activate the correct transformers version BEFORE any ML imports. If the model needs transformers 5.x, prepend the pre-installed .venv_t5/ directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/). """ # Ensure backend is on path for utils imports backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from utils.transformers_version import ( needs_transformers_5, _resolve_base_model, _ensure_venv_t5_exists, _VENV_T5_DIR, ) resolved = _resolve_base_model(model_name) if needs_transformers_5(resolved): if not _ensure_venv_t5_exists(): raise RuntimeError( f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}" ) if _VENV_T5_DIR not in sys.path: sys.path.insert(0, _VENV_T5_DIR) logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR) # Propagate to child subprocesses (e.g. 
GGUF converter) _pp = os.environ.get("PYTHONPATH", "") os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "") else: logger.info("Using default transformers (4.57.x) for %s", model_name) def _decode_image(image_base64: str): """Decode base64 string to PIL.Image.""" from PIL import Image image_data = base64.b64decode(image_base64) return Image.open(BytesIO(image_data)) def _resize_image(img, max_size: int = 800): """Resize image while maintaining aspect ratio.""" if img is None: return None if img.size[0] > max_size or img.size[1] > max_size: from PIL import Image ratio = min(max_size / img.size[0], max_size / img.size[1]) new_size = (int(img.size[0] * ratio), int(img.size[1] * ratio)) return img.resize(new_size, Image.Resampling.LANCZOS) return img def _send_response(resp_queue: Any, response: dict) -> None: """Send a response to the parent process.""" try: resp_queue.put(response) except (OSError, ValueError) as exc: logger.error("Failed to send response: %s", exc) def _build_model_config(config: dict): """Build a ModelConfig from the config dict.""" from utils.models import ModelConfig model_name = config["model_name"] hf_token = config.get("hf_token") hf_token = hf_token if hf_token and hf_token.strip() else None gguf_variant = config.get("gguf_variant") mc = ModelConfig.from_identifier( model_id = model_name, hf_token = hf_token, gguf_variant = gguf_variant, ) if not mc: raise ValueError(f"Invalid model identifier: {model_name}") return mc def _handle_load(backend, config: dict, resp_queue: Any) -> None: """Handle a load command: load a model into the backend.""" try: mc = _build_model_config(config) hf_token = config.get("hf_token") hf_token = hf_token if hf_token and hf_token.strip() else None # Auto-detect quantization for LoRA adapters load_in_4bit = config.get("load_in_4bit", True) if mc.is_lora and mc.path: import json from pathlib import Path adapter_cfg_path = Path(mc.path) / "adapter_config.json" if adapter_cfg_path.exists(): try: 
with open(adapter_cfg_path) as f: adapter_cfg = json.load(f) training_method = adapter_cfg.get("unsloth_training_method") if training_method == "lora" and load_in_4bit: logger.info( "adapter_config.json says lora — setting load_in_4bit=False" ) load_in_4bit = False elif training_method == "qlora" and not load_in_4bit: logger.info( "adapter_config.json says qlora — setting load_in_4bit=True" ) load_in_4bit = True elif not training_method: if ( mc.base_model and "-bnb-4bit" not in mc.base_model.lower() and load_in_4bit ): logger.info( "No training method, base model has no -bnb-4bit — setting load_in_4bit=False" ) load_in_4bit = False except Exception as e: logger.warning("Could not read adapter_config.json: %s", e) success = backend.load_model( config = mc, max_seq_length = config.get("max_seq_length", 2048), load_in_4bit = load_in_4bit, hf_token = hf_token, trust_remote_code = config.get("trust_remote_code", False), ) if success: # Build model_info for the parent to mirror model_info = { "identifier": mc.identifier, "display_name": mc.display_name, "is_vision": mc.is_vision, "is_lora": mc.is_lora, "is_gguf": False, "is_audio": getattr(mc, "is_audio", False), "audio_type": getattr(mc, "audio_type", None), "has_audio_input": getattr(mc, "has_audio_input", False), } _send_response( resp_queue, { "type": "loaded", "success": True, "model_info": model_info, "ts": time.time(), }, ) else: _send_response( resp_queue, { "type": "loaded", "success": False, "error": "Failed to load model", "ts": time.time(), }, ) except Exception as exc: _send_response( resp_queue, { "type": "loaded", "success": False, "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def _handle_generate( backend, cmd: dict, resp_queue: Any, cancel_event, ) -> None: """Handle a generate command: stream tokens back via resp_queue. cancel_event is an mp.Event shared with the parent process. The parent can set it at any time (e.g. 
user stops generation, or user loads a new model while generating) and generation stops within 1-2 tokens. """ request_id = cmd.get("request_id", "") try: # Decode image if provided image = None image_b64 = cmd.get("image_base64") if image_b64: image = _decode_image(image_b64) image = _resize_image(image) # Build generation kwargs gen_kwargs = { "messages": cmd["messages"], "system_prompt": cmd.get("system_prompt", ""), "image": image, "temperature": cmd.get("temperature", 0.7), "top_p": cmd.get("top_p", 0.9), "top_k": cmd.get("top_k", 40), "min_p": cmd.get("min_p", 0.0), "max_new_tokens": cmd.get("max_new_tokens", 256), "repetition_penalty": cmd.get("repetition_penalty", 1.0), "cancel_event": cancel_event, } # Choose generation path use_adapter = cmd.get("use_adapter") if use_adapter is not None: generator = backend.generate_with_adapter_control( use_adapter = use_adapter, **gen_kwargs, ) else: generator = backend.generate_chat_response(**gen_kwargs) logger.info("Starting text generation for request_id=%s", request_id) for cumulative_text in generator: # cancel_event is an mp.Event — checked instantly, no queue polling if cancel_event.is_set(): logger.info("Generation cancelled for request %s", request_id) break _send_response( resp_queue, { "type": "token", "request_id": request_id, "text": cumulative_text, "ts": time.time(), }, ) _send_response( resp_queue, { "type": "gen_done", "request_id": request_id, "ts": time.time(), }, ) logger.info("Finished text generation for request_id=%s", request_id) except Exception as exc: logger.error("Generation error: %s", exc, exc_info = True) _send_response( resp_queue, { "type": "gen_error", "request_id": request_id, "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def _handle_generate_audio( backend, cmd: dict, resp_queue: Any, ) -> None: """Handle TTS audio generation — returns WAV bytes + sample_rate.""" request_id = cmd.get("request_id", "") try: logger.info("Starting audio generation 
for request_id=%s", request_id) wav_bytes, sample_rate = backend.generate_audio_response( text = cmd["text"], temperature = cmd.get("temperature", 0.6), top_p = cmd.get("top_p", 0.95), top_k = cmd.get("top_k", 50), min_p = cmd.get("min_p", 0.0), max_new_tokens = cmd.get("max_new_tokens", 2048), repetition_penalty = cmd.get("repetition_penalty", 1.0), use_adapter = cmd.get("use_adapter"), ) # Send WAV bytes as base64 (bytes can't go through mp.Queue directly) _send_response( resp_queue, { "type": "audio_done", "request_id": request_id, "wav_base64": base64.b64encode(wav_bytes).decode("ascii"), "sample_rate": sample_rate, "ts": time.time(), }, ) logger.info("Finished audio generation for request_id=%s", request_id) except Exception as exc: logger.error("Audio generation error: %s", exc, exc_info = True) _send_response( resp_queue, { "type": "audio_error", "request_id": request_id, "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def _handle_generate_audio_input( backend, cmd: dict, resp_queue: Any, cancel_event, ) -> None: """Handle audio input generation (ASR/Whisper) — streams text tokens back.""" request_id = cmd.get("request_id", "") try: import numpy as np # Decode audio array from list (numpy arrays can't go through mp.Queue) audio_array = np.array(cmd["audio_data"], dtype = np.float32) audio_type = cmd.get("audio_type") if audio_type == "whisper": generator = backend.generate_whisper_response( audio_array = audio_array, cancel_event = cancel_event, ) else: generator = backend.generate_audio_input_response( messages = cmd.get("messages", []), system_prompt = cmd.get("system_prompt", ""), audio_array = audio_array, temperature = cmd.get("temperature", 0.7), top_p = cmd.get("top_p", 0.9), top_k = cmd.get("top_k", 40), min_p = cmd.get("min_p", 0.0), max_new_tokens = cmd.get("max_new_tokens", 512), repetition_penalty = cmd.get("repetition_penalty", 1.0), cancel_event = cancel_event, ) logger.info("Starting audio input generation 
for request_id=%s", request_id) for text_chunk in generator: if cancel_event.is_set(): logger.info( "Audio input generation cancelled for request %s", request_id ) break _send_response( resp_queue, { "type": "token", "request_id": request_id, "text": text_chunk, "ts": time.time(), }, ) _send_response( resp_queue, { "type": "gen_done", "request_id": request_id, "ts": time.time(), }, ) logger.info("Finished audio input generation for request_id=%s", request_id) except Exception as exc: logger.error("Audio input generation error: %s", exc, exc_info = True) _send_response( resp_queue, { "type": "gen_error", "request_id": request_id, "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) def _handle_unload(backend, cmd: dict, resp_queue: Any) -> None: """Handle an unload command.""" model_name = cmd.get("model_name", "") try: if model_name and model_name in backend.models: backend.unload_model(model_name) elif backend.active_model_name: backend.unload_model(backend.active_model_name) _send_response( resp_queue, { "type": "unloaded", "model_name": model_name, "ts": time.time(), }, ) except Exception as exc: logger.error("Unload error: %s", exc) _send_response( resp_queue, { "type": "unloaded", "model_name": model_name, "error": str(exc), "ts": time.time(), }, ) def run_inference_process( *, cmd_queue: Any, resp_queue: Any, cancel_event, config: dict, ) -> None: """Subprocess entrypoint. Persistent — runs command loop until shutdown. Args: cmd_queue: mp.Queue for receiving commands from parent. resp_queue: mp.Queue for sending responses to parent. cancel_event: mp.Event shared with parent — set by parent to cancel generation. config: Initial configuration dict with model info. 
""" os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTHONWARNINGS"] = ( "ignore" # Suppress warnings at C-level before imports ) import warnings from loggers.config import LogConfig if os.getenv("ENVIRONMENT_TYPE", "production") == "production": warnings.filterwarnings("ignore") LogConfig.setup_logging( service_name = "unsloth-studio-inference-worker", env = os.getenv("ENVIRONMENT_TYPE", "production"), ) model_name = config["model_name"] # ── 1. Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to activate transformers version: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 1b. On Windows, check Triton availability (must be before import torch) ── if sys.platform == "win32": try: import triton # noqa: F401 logger.info("Triton available — torch.compile enabled") except ImportError: os.environ["TORCHDYNAMO_DISABLE"] = "1" logger.warning( "Triton not found on Windows — torch.compile disabled. " 'Install for better performance: pip install "triton-windows<3.7"' ) # ── 2. Import ML libraries (fresh in this clean process) ── try: _send_response( resp_queue, { "type": "status", "message": "Importing Unsloth...", "ts": time.time(), }, ) backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from core.inference.inference import InferenceBackend import transformers logger.info("Subprocess loaded transformers %s", transformers.__version__) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to import ML libraries: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 3. 
Create inference backend and load initial model ── try: backend = InferenceBackend() _send_response( resp_queue, { "type": "status", "message": "Loading model...", "ts": time.time(), }, ) _handle_load(backend, config, resp_queue) except Exception as exc: _send_response( resp_queue, { "type": "error", "error": f"Failed to initialize inference backend: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) return # ── 4. Command loop — process commands until shutdown ── # cancel_event is an mp.Event shared with parent — parent can set it # at any time to cancel generation instantly (no queue polling needed). logger.info("Inference subprocess ready, entering command loop") while True: try: cmd = cmd_queue.get(timeout = 1.0) except _queue.Empty: continue except (EOFError, OSError): logger.info("Command queue closed, shutting down") return if cmd is None: continue cmd_type = cmd.get("type", "") logger.info("Received command: %s", cmd_type) try: if cmd_type == "generate": cancel_event.clear() _handle_generate(backend, cmd, resp_queue, cancel_event) elif cmd_type == "load": # Load a new model (reusing this subprocess) # First unload current model if backend.active_model_name: backend.unload_model(backend.active_model_name) _handle_load(backend, cmd, resp_queue) elif cmd_type == "generate_audio": cancel_event.clear() _handle_generate_audio(backend, cmd, resp_queue) elif cmd_type == "generate_audio_input": cancel_event.clear() _handle_generate_audio_input(backend, cmd, resp_queue, cancel_event) elif cmd_type == "unload": _handle_unload(backend, cmd, resp_queue) elif cmd_type == "cancel": # Redundant with mp.Event but handle gracefully cancel_event.set() logger.info("Cancel command received") elif cmd_type == "reset": cancel_event.set() backend.reset_generation_state() _send_response( resp_queue, { "type": "reset_ack", "ts": time.time(), }, ) elif cmd_type == "status": # Return current status _send_response( resp_queue, { "type": "status_response", 
"active_model": backend.active_model_name, "models": { name: { "is_vision": info.get("is_vision", False), "is_lora": info.get("is_lora", False), } for name, info in backend.models.items() }, "loading": list(backend.loading_models), "ts": time.time(), }, ) elif cmd_type == "shutdown": logger.info("Shutdown command received, exiting") # Unload all models for model_name in list(backend.models.keys()): try: backend.unload_model(model_name) except Exception: pass _send_response( resp_queue, { "type": "shutdown_ack", "ts": time.time(), }, ) return else: logger.warning("Unknown command type: %s", cmd_type) _send_response( resp_queue, { "type": "error", "error": f"Unknown command type: {cmd_type}", "ts": time.time(), }, ) except Exception as exc: logger.error( "Error handling command '%s': %s", cmd_type, exc, exc_info = True ) _send_response( resp_queue, { "type": "error", "error": f"Command '{cmd_type}' failed: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), }, ) ================================================ FILE: studio/backend/core/training/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Training submodule - Training backends and trainer classes """ from .training import TrainingBackend, TrainingProgress, get_training_backend __all__ = [ "TrainingProgress", "TrainingBackend", "get_training_backend", ] ================================================ FILE: studio/backend/core/training/trainer.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Unsloth Training Backend Integrates Unsloth training capabilities with the FastAPI backend """ import os import sys # Prevent tokenizer parallelism deadlocks when datasets uses multiprocessing fork os.environ["TOKENIZERS_PARALLELISM"] = "false" import torch from utils.hardware import clear_gpu_cache, safe_num_proc torch._dynamo.config.recompile_limit = 64 from unsloth import FastLanguageModel, FastVisionModel, is_bfloat16_supported from unsloth.chat_templates import get_chat_template import json import threading import math import structlog from loggers import get_logger import time from pathlib import Path from typing import Optional, Callable from dataclasses import dataclass import pandas as pd from datasets import Dataset, load_dataset from utils.models import is_vision_model, detect_audio_type from utils.datasets import format_and_template_dataset from utils.datasets import MODEL_TO_TEMPLATE_MAPPER, TEMPLATE_TO_RESPONSES_MAPPER from utils.paths import ( ensure_dir, resolve_dataset_path, resolve_output_dir, resolve_tensorboard_dir, ) from trl import SFTTrainer, SFTConfig logger = get_logger(__name__) def _build_report_targets(training_args) -> list[str] | str: report_to: list[str] = [] if training_args.get("enable_wandb", False): report_to.append("wandb") if training_args.get("enable_tensorboard", False): report_to.append("tensorboard") return report_to or "none" @dataclass class TrainingProgress: """Training progress tracking""" epoch: float = 0 step: int = 0 total_steps: int = 0 loss: float = 0.0 learning_rate: float = 0.0 is_training: bool = False is_completed: bool = False error: Optional[str] = None status_message: str = "Ready to train" # Current stage message elapsed_seconds: Optional[float] = None eta_seconds: Optional[float] = None grad_norm: Optional[float] = None num_tokens: Optional[int] = None eval_loss: Optional[float] = None class UnslothTrainer: """ Unsloth Training Backend """ def __init__(self): self.model = 
None self.tokenizer = None self.trainer = None self.training_thread = None self.training_progress = TrainingProgress() self.progress_callbacks = [] self.is_training = False self.should_stop = False self.save_on_stop = True self.load_in_4bit = True # Track quantization mode for metadata # Model state tracking self.is_vlm = False self.is_audio = False self.is_audio_vlm = ( False # Multimodal model (e.g. Gemma 3N) trained on audio data ) self._audio_type = None # 'csm', 'whisper', 'snac', 'bicodec', 'dac' self._cuda_audio_used = ( False # Set once after audio CUDA preprocessing; never cleared ) self._spark_tts_repo_dir = ( None # Path to downloaded Spark-TTS repo (for BiCodecTokenizer) ) self.model_name = None # Training metrics tracking self.training_start_time: Optional[float] = None self.batch_size: Optional[int] = None self.max_seq_length: Optional[int] = None self.gradient_accumulation_steps: Optional[int] = None # Thread safety self._lock = threading.Lock() # Store training context for later transfer self.training_context = { "base_model_name": None, "output_dir": None, "is_lora": True, # Default to LoRA } def pre_detect_and_load_tokenizer( self, model_name: str, max_seq_length: int = 2048, hf_token: Optional[str] = None, is_dataset_image: bool = False, is_dataset_audio: bool = False, trust_remote_code: bool = False, ) -> None: """Lightweight detection and tokenizer load — no model weights, no VRAM. Sets is_vlm, _audio_type, is_audio_vlm, model_name and loads a lightweight tokenizer for dataset formatting. Call this before load_and_format_dataset() when you want to process the dataset BEFORE loading the training model (avoids VRAM contention with the LLM-assisted detection helper). load_model() may be called afterwards — it will re-detect and load the full model + tokenizer, overwriting the lightweight one set here. 
""" self.model_name = model_name self.max_seq_length = max_seq_length self.trust_remote_code = trust_remote_code if hf_token: os.environ["HF_TOKEN"] = hf_token # --- Detect audio type (reads config.json only, no VRAM) --- self._audio_type = detect_audio_type(model_name, hf_token) if self._audio_type == "audio_vlm": self.is_audio = False self.is_audio_vlm = is_dataset_audio self._audio_type = None else: self.is_audio = self._audio_type is not None self.is_audio_vlm = False if not self.is_audio and not self.is_audio_vlm: self._cuda_audio_used = False # --- Detect VLM --- vision = is_vision_model(model_name) if not self.is_audio else False self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image logger.info( "pre_detect: audio_type=%s, is_audio=%s, is_audio_vlm=%s, is_vlm=%s", self._audio_type, self.is_audio, self.is_audio_vlm, self.is_vlm, ) # --- Load lightweight tokenizer/processor (CPU only, no VRAM) --- # Whisper needs AutoProcessor (has feature_extractor + tokenizer). # All others work with AutoTokenizer (CSM loads its own processor inline). 
if self._audio_type == "whisper": from transformers import AutoProcessor self.tokenizer = AutoProcessor.from_pretrained( model_name, trust_remote_code = trust_remote_code, token = hf_token, ) else: from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code = trust_remote_code, token = hf_token, ) logger.info("Pre-loaded tokenizer for %s", model_name) def add_progress_callback(self, callback: Callable[[TrainingProgress], None]): """Add callback for training progress updates""" self.progress_callbacks.append(callback) def _update_progress(self, **kwargs): """Update training progress and notify callbacks""" with self._lock: for key, value in kwargs.items(): if hasattr(self.training_progress, key): setattr(self.training_progress, key, value) # Notify all callbacks for callback in self.progress_callbacks: try: callback(self.training_progress) except Exception as e: logger.error(f"Error in progress callback: {e}") def _create_progress_callback(self): """Create a TrainerCallback for progress tracking. 
Reused by all training branches.""" from transformers import TrainerCallback trainer_ref = self class _ProgressCallback(TrainerCallback): def on_log(self, args, state, control, logs = None, **kwargs): if not logs: return loss_value = logs.get("loss", logs.get("train_loss", 0.0)) current_step = state.global_step grad_norm = logs.get("grad_norm", None) elapsed_seconds = None if trainer_ref.training_start_time is not None: elapsed_seconds = time.time() - trainer_ref.training_start_time eta_seconds = None if elapsed_seconds is not None and current_step > 0: total_steps = trainer_ref.training_progress.total_steps if total_steps > 0: steps_remaining = total_steps - current_step if steps_remaining > 0: eta_seconds = ( elapsed_seconds / current_step ) * steps_remaining num_tokens = getattr(state, "num_input_tokens_seen", None) trainer_ref._update_progress( step = current_step, epoch = round(state.epoch, 2) if state.epoch else 0, loss = loss_value, learning_rate = logs.get("learning_rate", 0.0), elapsed_seconds = elapsed_seconds, eta_seconds = eta_seconds, grad_norm = grad_norm, num_tokens = num_tokens, eval_loss = logs.get("eval_loss", None), status_message = "", ) def on_epoch_end(self, args, state, control, **kwargs): trainer_ref._update_progress(epoch = state.epoch, step = state.global_step) def on_step_end(self, args, state, control, **kwargs): if trainer_ref.should_stop: logger.info(f"Stop detected at step {state.global_step}\n") control.should_training_stop = True return control return _ProgressCallback() def _calculate_total_steps( self, num_samples, batch_size, grad_accum, num_epochs, max_steps ): """Calculate total training steps from dataset size and training params.""" if max_steps and max_steps > 0: return max_steps len_dataloader = math.ceil(num_samples / batch_size) steps_per_epoch = max( len_dataloader // grad_accum + int(len_dataloader % grad_accum > 0), 1 ) return steps_per_epoch * num_epochs def _build_audio_training_args(self, training_args, output_dir, 
*, extra_args = None): """Build training args dict for audio branches. Constructs the common config (batch size, lr, warmup, fp16/bf16, etc.) and applies per-branch overrides via extra_args. """ batch_size = training_args.get("batch_size", 2) gradient_accumulation_steps = training_args.get( "gradient_accumulation_steps", 4 ) warmup_steps_val = training_args.get("warmup_steps", 5) max_steps_val = training_args.get("max_steps", 0) learning_rate = training_args.get("learning_rate", 2e-4) weight_decay = training_args.get("weight_decay", 0.001) lr_scheduler_type = training_args.get("lr_scheduler_type", "linear") random_seed = training_args.get("random_seed", 3407) optim_value = training_args.get("optim", "adamw_8bit") config = { "per_device_train_batch_size": batch_size, "gradient_accumulation_steps": gradient_accumulation_steps, "warmup_steps": warmup_steps_val if warmup_steps_val is not None else 5, "learning_rate": learning_rate, "fp16": not is_bfloat16_supported(), "bf16": is_bfloat16_supported(), "logging_steps": 1, "optim": optim_value, "weight_decay": weight_decay, "lr_scheduler_type": lr_scheduler_type, "seed": random_seed, "output_dir": output_dir, "report_to": _build_report_targets(training_args), } if training_args.get("enable_tensorboard", False): config["logging_dir"] = str( resolve_tensorboard_dir(training_args.get("tensorboard_dir")) ) # max_steps vs epochs if max_steps_val and max_steps_val > 0: config["max_steps"] = max_steps_val else: config["num_train_epochs"] = training_args.get("num_epochs", 3) # save_steps save_steps_val = training_args.get("save_steps", 0) if save_steps_val and save_steps_val > 0: config["save_steps"] = save_steps_val config["save_strategy"] = "steps" # Apply per-branch overrides if extra_args: config.update(extra_args) return config def _finalize_training(self, output_dir, label = ""): """Save model after training and update progress. 
Used by all training branches.""" if self.should_stop and self.save_on_stop: self.trainer.save_model() self.tokenizer.save_pretrained(output_dir) self._patch_adapter_config(output_dir) msg = f"{label} training stopped" if label else "Training stopped" logger.info(f"\n{msg}. Model saved to {output_dir}\n") self._update_progress( is_training = False, status_message = f"Training stopped. Model saved to {output_dir}", ) elif self.should_stop: msg = f"{label} training cancelled" if label else "Training cancelled" logger.info(f"\n{msg}.\n") self._update_progress( is_training = False, status_message = "Training cancelled." ) else: self.trainer.save_model() self.tokenizer.save_pretrained(output_dir) self._patch_adapter_config(output_dir) msg = f"{label} training completed" if label else "Training completed" logger.info(f"\n{msg}! Model saved to {output_dir}\n") self._update_progress( is_training = False, is_completed = True, status_message = f"Training completed! Model saved to {output_dir}", ) def _cleanup_audio_artifacts(self): """Remove sys.path entries and sys.modules from previous audio preprocessing. After audio training, cloned repo dirs (OuteTTS, Spark-TTS) remain on sys.path and heavy audio modules (snac, whisper, sparktts, outetts) stay in sys.modules. When the next training run calls dataset.map(num_proc=N), forked child processes inherit this stale state and deadlock. 
""" import sys as _sys # Remove cloned audio repo paths from sys.path base_dir = os.path.dirname(os.path.abspath(__file__)) audio_paths = [ os.path.join(base_dir, "inference", "OuteTTS"), # DAC/OuteTTS ] # Spark-TTS path is relative to the downloaded repo if self._spark_tts_repo_dir: spark_code_dir = os.path.join( os.path.dirname(self._spark_tts_repo_dir), "Spark-TTS" ) audio_paths.append(spark_code_dir) removed_paths = [] for path in audio_paths: if path in _sys.path: _sys.path.remove(path) removed_paths.append(path) # Remove stale audio modules from sys.modules prefixes = ("snac", "whisper", "sparktts", "outetts") removed_modules = [key for key in _sys.modules if key.startswith(prefixes)] for key in removed_modules: del _sys.modules[key] if removed_paths or removed_modules: logger.info( f"Cleaned up audio artifacts: {len(removed_paths)} paths, " f"{len(removed_modules)} modules\n" ) def _resolve_audio_columns(self, dataset, custom_format_mapping: dict = None): """Resolve audio, text, and speaker columns from user mapping or hardcoded fallback. 
Returns: dict with keys: audio_col, text_col, speaker_col (speaker_col may be None) """ cols = dataset.column_names if custom_format_mapping: audio_col = None text_col = None speaker_col = None for col, role in custom_format_mapping.items(): if role == "audio": audio_col = col elif role == "text": text_col = col elif role == "speaker_id": speaker_col = col # Use mapping if both required columns exist in the dataset if audio_col and audio_col in cols and text_col and text_col in cols: return { "audio_col": audio_col, "text_col": text_col, "speaker_col": speaker_col, } # Hardcoded fallback (existing behavior) audio_col = next((c for c in cols if c.lower() in ("audio", "speech")), None) text_col = next( ( c for c in cols if c.lower() in ("text", "sentence", "transcript", "transcription") ), None, ) speaker_col = None if "source" in cols: speaker_col = "source" elif "speaker_id" in cols: speaker_col = "speaker_id" return { "audio_col": audio_col, "text_col": text_col, "speaker_col": speaker_col, } def load_model( self, model_name: str, max_seq_length: int = 2048, load_in_4bit: bool = True, hf_token: Optional[str] = None, is_dataset_image: bool = False, is_dataset_audio: bool = False, trust_remote_code: bool = False, full_finetuning: bool = False, ) -> bool: """Load model for training (supports both text and vision models)""" self.load_in_4bit = load_in_4bit # Store for training_meta.json self.trust_remote_code = ( trust_remote_code # For AutoProcessor etc. used during training ) try: if self.model is not None: del self.model if self.tokenizer is not None: del self.tokenizer if self.trainer is not None: del self.trainer logger.info("\nClearing GPU memory before training...") clear_gpu_cache() # Clean up sys.path and sys.modules from previous audio preprocessing # to prevent deadlocks when forking worker processes in dataset.map() self._cleanup_audio_artifacts() # Reload Unsloth-patched transformers modeling modules before clearing # the compiled cache. 
unsloth_compile_transformers() sets __UNSLOTH_PATCHED__ # on each modeling module and replaces methods with exec'd code. # clear_unsloth_compiled_cache() deletes the disk cache, but the flag # prevents re-compilation — leaving missing cache files. Reloading # restores original class definitions so Unsloth can re-compile cleanly. import sys as _sys import importlib for _key, _mod in list(_sys.modules.items()): if "transformers.models." in _key and ".modeling_" in _key: if hasattr(_mod, "__UNSLOTH_PATCHED__"): try: importlib.reload(_mod) except Exception: pass # Non-critical — Unsloth will handle stale modules # Remove stale compiled cache so the new model gets a fresh one from utils.cache_cleanup import clear_unsloth_compiled_cache clear_unsloth_compiled_cache() # Detect audio model type dynamically (config.json + tokenizer) self._audio_type = detect_audio_type(model_name, hf_token) # audio_vlm is detected as an audio_type now, handle it separately if self._audio_type == "audio_vlm": self.is_audio = False self.is_audio_vlm = ( is_dataset_audio # Only use audio VLM path if dataset has audio ) self._audio_type = None else: self.is_audio = self._audio_type is not None self.is_audio_vlm = False if not self.is_audio and not self.is_audio_vlm: self._cuda_audio_used = False # VLM: vision model with image dataset (mutually exclusive with audio paths) vision = is_vision_model(model_name) if not self.is_audio else False self.is_vlm = not self.is_audio_vlm and vision and is_dataset_image self.model_name = model_name self.max_seq_length = max_seq_length logger.info( f"Audio type: {self._audio_type}, is_audio: {self.is_audio}, is_audio_vlm: {self.is_audio_vlm}" ) logger.info( f"Dataset has images: {is_dataset_image}, audio: {is_dataset_audio}" ) logger.info(f"Using VLM path: {self.is_vlm}") # Reset training state for new run self._update_progress( is_training = True, is_completed = False, error = None, step = 0, loss = 0.0, epoch = 0, ) # Update UI immediately with loading 
message model_display = ( model_name.split("/")[-1] if "/" in model_name else model_name ) model_type_label = ( "audio" if self.is_audio else ("vision" if self.is_vlm else "text") ) self._update_progress( status_message = f"Loading {model_type_label} model... {model_display}" ) logger.info(f"\nLoading {model_type_label} model: {model_name}") # Set HF token if provided if hf_token: os.environ["HF_TOKEN"] = hf_token # Proactive gated-model check: verify access BEFORE from_pretrained. # Catches ALL gated/private models (text, vision, audio) globally. if "/" in model_name: # Only check HF repo IDs, not local paths try: from huggingface_hub import model_info as hf_model_info info = hf_model_info(model_name, token = hf_token or None) # model_info succeeds even for gated repos (metadata is public), # but info.gated tells us if files require acceptance/token. if info.gated and not hf_token: friendly = ( f"Access denied for '{model_name}'. This model is gated. " f"Please add a Hugging Face token with access and try again." ) logger.error( f"Model '{model_name}' is gated (gated={info.gated}) and no HF token provided" ) self._update_progress(error = friendly, is_training = False) return False except Exception as gate_err: from huggingface_hub.utils import ( GatedRepoError, RepositoryNotFoundError, ) if isinstance(gate_err, (GatedRepoError, RepositoryNotFoundError)): friendly = ( f"Access denied for '{model_name}'. This model is gated or private. " f"Please add a Hugging Face token with access and try again." 
) logger.error(f"Gated model check failed: {gate_err}") self._update_progress(error = friendly, is_training = False) return False # Branch based on model type if self._audio_type == "csm": # CSM: FastModel + auto_model=CsmForConditionalGeneration + load_in_4bit=False from unsloth import FastModel from transformers import CsmForConditionalGeneration self.model, self.tokenizer = FastModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = None, auto_model = CsmForConditionalGeneration, load_in_4bit = False, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded CSM audio model") elif self._audio_type == "whisper": # Whisper: FastModel + auto_model=WhisperForConditionalGeneration + load_in_4bit=False from unsloth import FastModel from transformers import WhisperForConditionalGeneration self.model, self.tokenizer = FastModel.from_pretrained( model_name = model_name, dtype = None, load_in_4bit = False, full_finetuning = full_finetuning, auto_model = WhisperForConditionalGeneration, whisper_language = "English", whisper_task = "transcribe", token = hf_token, trust_remote_code = trust_remote_code, ) # Configure generation settings (notebook lines 100-105) self.model.generation_config.language = "<|en|>" self.model.generation_config.task = "transcribe" self.model.config.suppress_tokens = [] self.model.generation_config.forced_decoder_ids = None logger.info("Loaded Whisper audio model (FastModel)") elif self._audio_type == "snac": # Orpheus: language model with audio codec tokens self.model, self.tokenizer = FastLanguageModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = None, load_in_4bit = load_in_4bit, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info( f"Loaded {self._audio_type} audio model (FastLanguageModel)" ) elif self._audio_type == "bicodec": # Spark-TTS: download full repo (contains 
sparktts package + BiCodec weights), # then load only the LLM subfolder with FastModel. # model_name may be: # "Spark-TTS-0.5B/LLM" (local-style, from YAML mapping) # "unsloth/Spark-TTS-0.5B" (HF repo ID) from unsloth import FastModel from huggingface_hub import snapshot_download if model_name.endswith("/LLM"): # "Spark-TTS-0.5B/LLM" → parent="Spark-TTS-0.5B" local_dir = model_name.rsplit("/", 1)[0] hf_repo = f"unsloth/{local_dir}" llm_path = model_name else: # "unsloth/Spark-TTS-0.5B" → local_dir="Spark-TTS-0.5B" hf_repo = model_name local_dir = model_name.split("/")[-1] llm_path = f"{local_dir}/LLM" repo_path = snapshot_download(hf_repo, local_dir = local_dir) self._spark_tts_repo_dir = os.path.abspath( repo_path ) # Absolute path for sys.path llm_path = os.path.join(self._spark_tts_repo_dir, "LLM") self.model, self.tokenizer = FastModel.from_pretrained( model_name = llm_path, max_seq_length = max_seq_length, dtype = torch.float32, # Spark-TTS requires float32 load_in_4bit = False, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded Spark-TTS (bicodec) model") elif self._audio_type == "dac": # OuteTTS: uses FastModel (not FastLanguageModel) with load_in_4bit=False from unsloth import FastModel self.model, self.tokenizer = FastModel.from_pretrained( model_name, max_seq_length = max_seq_length, load_in_4bit = False, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded OuteTTS (dac) model (FastModel)") elif self.is_audio_vlm: # Audio VLM: multimodal model trained on audio (e.g. 
Gemma 3N) # Uses FastModel (general loader) — returns (model, processor) from unsloth import FastModel self.model, self.tokenizer = FastModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = None, load_in_4bit = load_in_4bit, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded audio VLM model (FastModel)") elif self.is_vlm: # Load vision model - returns (model, tokenizer) self.model, self.tokenizer = FastVisionModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = None, # Auto-detect load_in_4bit = load_in_4bit, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded vision model") # Diagnostic: check if FastVisionModel returned a real Processor or a raw tokenizer from transformers import ProcessorMixin tok = self.tokenizer has_image_proc = isinstance(tok, ProcessorMixin) or hasattr( tok, "image_processor" ) logger.info( f"\n[VLM Diagnostic] FastVisionModel returned: {type(tok).__name__}" ) logger.info( f"[VLM Diagnostic] Is ProcessorMixin: {isinstance(tok, ProcessorMixin)}" ) logger.info( f"[VLM Diagnostic] Has image_processor: {hasattr(tok, 'image_processor')}" ) logger.info( f"[VLM Diagnostic] Usable as vision processor: {has_image_proc}\n" ) else: # Load text model - returns (model, tokenizer) self.model, self.tokenizer = FastLanguageModel.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, dtype = None, # Auto-detect load_in_4bit = load_in_4bit, full_finetuning = full_finetuning, token = hf_token, trust_remote_code = trust_remote_code, ) logger.info("Loaded text model") if self.should_stop: return False if full_finetuning: # Enable training mode for full fine-tuning # This ensures all model parameters are trainable; otherwise, they might be frozen. 
self.model.for_training() self._update_progress(status_message = "Model loaded successfully") logger.info("Model loaded successfully") return True except OSError as e: if "could not get source code" in str(e) and not getattr( self, "_source_code_retried", False ): # Unsloth's patching can leave stale state that makes # inspect.getsource() fail when switching model families # (e.g. gemma3 → gemma3n). The load always succeeds on a # second attempt because the failed first call's partial # imports clean up the stale state as a side effect. self._source_code_retried = True logger.info(f"\n'could not get source code' — retrying once...\n") return self.load_model( model_name = model_name, max_seq_length = max_seq_length, load_in_4bit = load_in_4bit, hf_token = hf_token, is_dataset_image = is_dataset_image, is_dataset_audio = is_dataset_audio, trust_remote_code = trust_remote_code, full_finetuning = full_finetuning, ) error_msg = str(e) error_lower = error_msg.lower() if any( k in error_lower for k in ( "gated repo", "access to it at", "401", "403", "unauthorized", "forbidden", ) ): error_msg = ( f"Access denied for '{model_name}'. This model is gated or private. " f"Please add a Hugging Face token with access and try again." ) logger.error(f"Error loading model: {e}") self._update_progress(error = error_msg, is_training = False) return False except Exception as e: error_msg = str(e) # Catch gated/auth errors and surface a friendly message error_lower = error_msg.lower() if any( k in error_lower for k in ( "gated repo", "access to it at", "401", "403", "unauthorized", "forbidden", ) ): error_msg = ( f"Access denied for '{model_name}'. This model is gated or private. " f"Please add a Hugging Face token with access and try again." 
) logger.error(f"Error loading model: {e}") self._update_progress(error = error_msg, is_training = False) return False finally: self._source_code_retried = False def prepare_model_for_training( self, use_lora: bool = True, # Vision-specific LoRA parameters (only used if is_vlm=True) finetune_vision_layers: bool = True, finetune_language_layers: bool = True, finetune_attention_modules: bool = True, finetune_mlp_modules: bool = True, # Standard LoRA parameters target_modules: list = None, lora_r: int = 16, lora_alpha: int = 16, lora_dropout: float = 0.0, use_gradient_checkpointing: str = "unsloth", use_rslora: bool = False, use_loftq: bool = False, ) -> bool: """ Prepare model for training (with optional LoRA). """ try: if self.model is None: raise ValueError("Model not loaded. Call load_model() first.") # Full finetuning mode - skip PEFT entirely if not use_lora: self._update_progress( status_message = "Full finetuning mode - no LoRA adapters" ) logger.info("Full finetuning mode - training all parameters\n") return True # LoRA/QLoRA mode - apply PEFT # "all-linear" is a PEFT keyword that targets every linear layer if isinstance(target_modules, list) and "all-linear" in target_modules: if len(target_modules) == 1: target_modules = "all-linear" else: target_modules = [m for m in target_modules if m != "all-linear"] elif target_modules is None or ( isinstance(target_modules, list) and len(target_modules) == 0 ): target_modules = [ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ] # Validate and normalize gradient_checkpointing # Must be one of: True, False, or "unsloth" if isinstance(use_gradient_checkpointing, str): use_gradient_checkpointing = use_gradient_checkpointing.strip().lower() if ( use_gradient_checkpointing == "" or use_gradient_checkpointing == "unsloth" ): use_gradient_checkpointing = "unsloth" elif use_gradient_checkpointing in ("true", "1", "yes"): use_gradient_checkpointing = True elif use_gradient_checkpointing in 
("false", "0", "no"): use_gradient_checkpointing = False else: # Invalid value, default to "unsloth" logger.warning( f"Invalid gradient_checkpointing value: {use_gradient_checkpointing}, defaulting to 'unsloth'" ) use_gradient_checkpointing = "unsloth" elif use_gradient_checkpointing not in (True, False, "unsloth"): # Invalid type or value, default to "unsloth" logger.warning( f"Invalid gradient_checkpointing type/value: {use_gradient_checkpointing}, defaulting to 'unsloth'" ) use_gradient_checkpointing = "unsloth" # Verify model is loaded if self.model is None: error_msg = "Model is None - model was not loaded properly" logger.error(error_msg) self._update_progress(error = error_msg) return False # Check if model has the expected attributes if not hasattr(self.model, "config"): error_msg = "Model does not have config attribute - model may not be loaded correctly" logger.error(error_msg) self._update_progress(error = error_msg) return False logger.info( f"Configuring LoRA adapters (r={lora_r}, alpha={lora_alpha})...\n" ) logger.info( f"Gradient checkpointing: {use_gradient_checkpointing} (type: {type(use_gradient_checkpointing).__name__})\n" ) # Branch based on model type: audio, audio_vlm, vision, or text if self._audio_type in ("csm", "bicodec", "dac") or self.is_audio_vlm: # Models using FastModel.get_peft_model (codec audio + audio VLM) from unsloth import FastModel label = self._audio_type or "audio_vlm" logger.info(f"{label} LoRA configuration:") logger.info(f" - Target modules: {target_modules}") if self.is_audio_vlm: logger.info(f" - Finetune vision layers: {finetune_vision_layers}") logger.info( f" - Finetune language layers: {finetune_language_layers}" ) logger.info( f" - Finetune attention modules: {finetune_attention_modules}" ) logger.info(f" - Finetune MLP modules: {finetune_mlp_modules}") logger.info() peft_kwargs = dict( r = lora_r, target_modules = target_modules, lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = "none", 
use_gradient_checkpointing = use_gradient_checkpointing, random_state = 3407, use_rslora = use_rslora, loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if use_loftq else None, ) # Audio VLM models support VLM-style layer selection if self.is_audio_vlm: peft_kwargs.update( finetune_vision_layers = finetune_vision_layers, finetune_language_layers = finetune_language_layers, finetune_attention_modules = finetune_attention_modules, finetune_mlp_modules = finetune_mlp_modules, ) self.model = FastModel.get_peft_model(self.model, **peft_kwargs) elif self._audio_type == "whisper": # Phase 2: Whisper uses FastModel.get_peft_model with task_type=None from unsloth import FastModel logger.info(f"Audio model (whisper) LoRA configuration:") logger.info(f" - Target modules: {target_modules}\n") self.model = FastModel.get_peft_model( self.model, r = lora_r, target_modules = target_modules, lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = "none", use_gradient_checkpointing = use_gradient_checkpointing, random_state = 3407, use_rslora = use_rslora, loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if use_loftq else None, task_type = None, ) elif self._audio_type == "snac": # Orpheus uses FastLanguageModel.get_peft_model logger.info(f"Audio model ({self._audio_type}) LoRA configuration:") logger.info(f" - Target modules: {target_modules}\n") self.model = FastLanguageModel.get_peft_model( self.model, r = lora_r, target_modules = target_modules, lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = "none", use_gradient_checkpointing = use_gradient_checkpointing, random_state = 3407, use_rslora = use_rslora, loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if use_loftq else None, ) elif self.is_vlm: # Vision model LoRA logger.info(f"Vision model LoRA configuration:") logger.info(f" - Finetune vision layers: {finetune_vision_layers}") logger.info(f" - Finetune language layers: {finetune_language_layers}") logger.info( f" - Finetune attention modules: 
{finetune_attention_modules}" ) logger.info(f" - Finetune MLP modules: {finetune_mlp_modules}\n") self.model = FastVisionModel.get_peft_model( self.model, finetune_vision_layers = finetune_vision_layers, finetune_language_layers = finetune_language_layers, finetune_attention_modules = finetune_attention_modules, finetune_mlp_modules = finetune_mlp_modules, r = lora_r, target_modules = target_modules, lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = "none", use_gradient_checkpointing = use_gradient_checkpointing, random_state = 3407, use_rslora = use_rslora, loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if use_loftq else None, ) else: # Text model LoRA logger.info(f"Text model LoRA configuration:") logger.info(f" - Target modules: {target_modules}\n") self.model = FastLanguageModel.get_peft_model( self.model, r = lora_r, target_modules = target_modules, lora_alpha = lora_alpha, lora_dropout = lora_dropout, bias = "none", use_gradient_checkpointing = use_gradient_checkpointing, random_state = 3407, use_rslora = use_rslora, loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if use_loftq else None, ) # Check if stopped during LoRA preparation if self.should_stop: logger.info("Stopped during LoRA configuration\n") return False self._update_progress(status_message = "LoRA adapters configured") logger.info("LoRA adapters configured successfully\n") return True except Exception as e: import traceback import sys error_details = ( f"{type(e).__name__}: {str(e)}" if str(e) else f"{type(e).__name__} (no message)" ) full_traceback = traceback.format_exc() logger.error(f"Error preparing model: {error_details}") logger.error(f"Full traceback:\n{full_traceback}") logger.info(f"\n[ERROR] Error preparing model: {error_details}") logger.info(f"[ERROR] Full traceback:\n{full_traceback}") self._update_progress(error = error_details) return False def _apply_csm_forward_fix(self): """Monkey-patch CsmForConditionalGeneration.forward to fix depth decoder kwargs. 
The original transformers forward passes raw **kwargs (num_items_in_batch, causal_mask, etc.) from the Trainer/PEFT through to the depth decoder, causing depth_decoder_loss=None and 'Tensor + NoneType' crash. We patch at both instance AND class level for maximum reliability, and strip non-TransformersKwargs params that Unsloth/PEFT inject. """ import types import torch import torch.nn as nn from transformers.models.csm.modeling_csm import ( CsmForConditionalGeneration, CsmOutputWithPast, ) base_csm = self.model.base_model.model # CsmForConditionalGeneration # Save original forward (the @can_return_tuple wrapped version) _original_forward = CsmForConditionalGeneration.forward # Keys that the depth decoder and its sub-layers actually understand _TRANSFORMERS_KWARGS = { "num_items_in_batch", "output_hidden_states", "output_attentions", "output_router_logits", "cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k", } def _fixed_csm_forward( self, input_ids = None, input_values = None, attention_mask = None, input_values_cutoffs = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, cache_position = None, logits_to_keep = 0, **kwargs, ): # Strip non-standard kwargs injected by Unsloth/PEFT (causal_mask, # num_logits_to_keep, task_ids, return_dict, etc.) 
output_attentions = kwargs.pop("output_attentions", None) output_hidden_states = kwargs.pop("output_hidden_states", None) kwargs.pop("return_dict", None) kwargs.pop("causal_mask", None) kwargs.pop("num_logits_to_keep", None) kwargs.pop("task_ids", None) # Only keep recognized TransformersKwargs clean_kwargs = { k: v for k, v in kwargs.items() if k in _TRANSFORMERS_KWARGS } if input_ids is not None and input_ids.ndim == 2: merged = self._merge_input_ids_with_input_values( input_ids, input_values, input_values_cutoffs, labels ) inputs_embeds = merged["inputs_embeds"] labels = merged["labels"] input_ids = None backbone_outputs = self.backbone_model( input_ids = input_ids, attention_mask = attention_mask, position_ids = position_ids, past_key_values = past_key_values, inputs_embeds = inputs_embeds, use_cache = use_cache, cache_position = cache_position, output_attentions = output_attentions, output_hidden_states = output_hidden_states, **clean_kwargs, ) backbone_hidden_states = backbone_outputs[0] slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :]) loss = None backbone_loss = None depth_decoder_loss = None depth_decoder_outputs = None if labels is not None: backbone_labels = labels[:, :, 0] backbone_loss = self.loss_function( logits = backbone_logits, labels = backbone_labels, vocab_size = self.config.vocab_size, **clean_kwargs, ) train_mask = ~(labels[:, :, 1:] == -100).all(dim = -1) depth_decoder_input_ids = labels[train_mask][ ..., : self.config.num_codebooks - 1 ] depth_decoder_input_ids = nn.functional.pad( depth_decoder_input_ids, (1, 0), value = 0 ) train_idxs = train_mask.nonzero(as_tuple = True) backbone_last_hidden_states = backbone_hidden_states[ train_idxs[0], train_idxs[1] - 1, : ] depth_decoder_labels = labels[train_mask] # Build clean kwargs for depth decoder dd_kwargs = clean_kwargs.copy() # Scale num_items_in_batch for depth 
decoder (31 codebooks) if "num_items_in_batch" in dd_kwargs: dd_kwargs["num_items_in_batch"] = dd_kwargs[ "num_items_in_batch" ] * (self.config.num_codebooks - 1) depth_decoder_outputs = self.depth_decoder( input_ids = depth_decoder_input_ids, backbone_last_hidden_state = backbone_last_hidden_states, use_cache = False, return_dict = True, labels = depth_decoder_labels, output_attentions = output_attentions, output_hidden_states = output_hidden_states, **dd_kwargs, ) depth_decoder_loss = depth_decoder_outputs.loss if depth_decoder_loss is None: logger.warning( "CSM depth_decoder_loss is None! " f"labels shape={depth_decoder_labels.shape}, " f"train_mask sum={train_mask.sum().item()}" ) # Fallback: use only backbone loss instead of crashing loss = backbone_loss else: loss = backbone_loss + depth_decoder_loss return CsmOutputWithPast( loss = loss, backbone_loss = backbone_loss, depth_decoder_loss = depth_decoder_loss, logits = backbone_logits, past_key_values = backbone_outputs.past_key_values, hidden_states = backbone_outputs.hidden_states, attentions = backbone_outputs.attentions, depth_decoder_logits = ( depth_decoder_outputs.logits if depth_decoder_outputs else None ), depth_decoder_past_key_values = ( depth_decoder_outputs.past_key_values if depth_decoder_outputs else None ), depth_decoder_hidden_states = ( depth_decoder_outputs.hidden_states if depth_decoder_outputs else None ), depth_decoder_attentions = ( depth_decoder_outputs.attentions if depth_decoder_outputs else None ), ) # Patch at BOTH instance and class level for maximum reliability. 
# Instance-level: catches calls via BaseTuner.forward -> self.model.forward() base_csm.forward = types.MethodType(_fixed_csm_forward, base_csm) # Class-level: catches any path that resolves through the class dict CsmForConditionalGeneration.forward = _fixed_csm_forward logger.info("Applied CSM forward fix (class + instance level)\n") def _preprocess_csm_dataset(self, dataset, custom_format_mapping = None): """Preprocess dataset for CSM TTS training (exact notebook copy).""" from transformers import AutoProcessor from datasets import Audio import torch processor = AutoProcessor.from_pretrained( self.model_name, trust_remote_code = getattr(self, "trust_remote_code", False), ) # Strip pad_to_multiple_of from tokenizer init_kwargs — fine-tuned models # (e.g. keanteng/sesame-csm-elise) save it in tokenizer_config.json, and # _merge_kwargs leaks it into audio_kwargs where EncodecFeatureExtractor rejects it. processor.tokenizer.init_kwargs.pop("pad_to_multiple_of", None) # Resolve columns from user mapping or hardcoded fallback resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] speaker_key = resolved["speaker_col"] if audio_col is None: raise ValueError( f"No audio column found in dataset. Columns: {dataset.column_names}" ) if text_col is None: raise ValueError( f"No text column found in dataset. 
Columns: {dataset.column_names}" ) if speaker_key is None: logger.info( "No speaker found, adding default 'source' of 0 for all examples\n" ) dataset = dataset.add_column("source", ["0"] * len(dataset)) speaker_key = "source" logger.info( f"CSM preprocessing: audio_col='{audio_col}', text_col='{text_col}', speaker_key='{speaker_key}'\n" ) dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 24000)) required_keys = [ "input_ids", "attention_mask", "labels", "input_values", "input_values_cutoffs", ] self._update_progress(status_message = "Preprocessing CSM dataset...") processed_examples = [] skipped = 0 for idx in range(len(dataset)): if self.should_stop: logger.info("Stopped during CSM preprocessing\n") break example = dataset[idx] try: conversation = [ { "role": str(example[speaker_key]), "content": [ {"type": "text", "text": example.get(text_col, "")}, {"type": "audio", "path": example[audio_col]["array"]}, ], } ] # NOTE: pad_to_multiple_of intentionally omitted from text_kwargs — # CsmProcessor._merge_kwargs leaks it to EncodecFeatureExtractor which rejects it. model_inputs = processor.apply_chat_template( conversation, tokenize = True, return_dict = True, output_labels = True, text_kwargs = { "padding": "max_length", "max_length": 256, "padding_side": "right", }, audio_kwargs = { "sampling_rate": 24_000, "max_length": 240001, "padding": "max_length", }, common_kwargs = {"return_tensors": "pt"}, ) out = {} for k in required_keys: if k not in model_inputs: raise KeyError(f"Missing required key '{k}' in model outputs") out[k] = model_inputs[k][0] if not all(isinstance(out[k], torch.Tensor) for k in out): skipped += 1 continue processed_examples.append(out) except Exception as e: logger.warning(f"Error processing CSM example {idx}: {e}") skipped += 1 continue if (idx + 1) % 100 == 0: self._update_progress( status_message = f"Preprocessing CSM... 
{idx + 1}/{len(dataset)}" ) if not processed_examples: raise ValueError( f"No valid examples after CSM preprocessing (skipped {skipped})" ) result_dataset = Dataset.from_list(processed_examples) logger.info( f"CSM preprocessing complete: {len(result_dataset)} examples " f"({skipped} skipped)\n" ) return result_dataset def _format_audio_vlm_dataset(self, dataset, custom_format_mapping = None): """Format dataset as audio chat messages for multimodal models (e.g. Gemma 3N). Expects columns: audio (Audio), text (str). Produces: messages column with system/user/assistant chat format. """ from datasets import Audio resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] if not audio_col or not text_col: raise ValueError( f"Audio VLM dataset needs 'audio' and 'text' columns, got: {dataset.column_names}" ) # Store resolved audio column name for the collator closure self._audio_vlm_audio_col = audio_col # Cast audio to 16kHz (standard for speech models) dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 16000)) def format_messages(samples): formatted = {"messages": []} for idx in range(len(samples[audio_col])): audio = samples[audio_col][idx]["array"] label = str(samples[text_col][idx]) message = [ { "role": "system", "content": [ { "type": "text", "text": "You are an assistant that transcribes speech accurately.", } ], }, { "role": "user", "content": [ {"type": "audio", "audio": audio}, {"type": "text", "text": "Please transcribe this audio."}, ], }, {"role": "assistant", "content": [{"type": "text", "text": label}]}, ] formatted["messages"].append(message) return formatted self._update_progress(status_message = "Formatting audio VLM dataset...") dataset = dataset.map( format_messages, batched = True, batch_size = 4, num_proc = safe_num_proc(4) ) logger.info(f"Audio VLM dataset formatted: {len(dataset)} examples\n") return dataset def _preprocess_snac_dataset(self, dataset, 
custom_format_mapping = None): """Preprocess dataset for Orpheus TTS training with SNAC codec. Mirrors Orpheus_(3B)-TTS.ipynb: encode audio with SNAC (24kHz, 3 hierarchical layers), interleave 7 codes per frame, wrap with Orpheus special tokens, train on full sequence (no label masking). """ import torch import torchaudio.transforms as T SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz" SNAC_SAMPLE_RATE = 24000 device = "cuda" if torch.cuda.is_available() else "cpu" max_length = self.max_seq_length or 2048 tokenizer = self.tokenizer # Orpheus special token IDs (hardcoded in tokenizer vocabulary) START_OF_HUMAN = 128259 END_OF_HUMAN = 128260 START_OF_AI = 128261 END_OF_AI = 128262 START_OF_SPEECH = 128257 END_OF_SPEECH = 128258 END_OF_TEXT = 128009 AUDIO_OFFSET = 128266 resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] speaker_col = resolved["speaker_col"] has_source = speaker_col is not None if not audio_col or not text_col: raise ValueError( f"SNAC dataset needs 'audio' and 'text' columns, got: {dataset.column_names}" ) # Cast audio column so datasets 4.x AudioDecoder objects are decoded to dicts from datasets import Audio dataset = dataset.cast_column(audio_col, Audio(sampling_rate = SNAC_SAMPLE_RATE)) # Get dataset sample rate from first example (after cast, always SNAC_SAMPLE_RATE) first_audio = dataset[0][audio_col] ds_sample_rate = ( first_audio.get("sampling_rate", SNAC_SAMPLE_RATE) if isinstance(first_audio, dict) else SNAC_SAMPLE_RATE ) # Load SNAC codec model self._update_progress(status_message = "Loading SNAC codec model...") logger.info("Loading SNAC codec model...\n") from snac import SNAC snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME) snac_model = snac_model.to(device).eval() # Resample transform (created once) resample_transform = ( T.Resample(orig_freq = ds_sample_rate, new_freq = SNAC_SAMPLE_RATE) if ds_sample_rate != SNAC_SAMPLE_RATE else None ) 
self._update_progress(status_message = "Encoding audio with SNAC...") logger.info( f"SNAC preprocessing: audio_col='{audio_col}', text_col='{text_col}', " f"has_source={has_source}, ds_sample_rate={ds_sample_rate}\n" ) processed_examples = [] skipped = 0 for idx in range(len(dataset)): if self.should_stop: logger.info("Stopped during SNAC preprocessing\n") break example = dataset[idx] try: text = example.get(text_col) if not text: skipped += 1 continue audio_data = example.get(audio_col) if audio_data is None or audio_data.get("array") is None: skipped += 1 continue # --- Encode audio with SNAC (notebook lines 122-142) --- waveform = ( torch.from_numpy(audio_data["array"]) .unsqueeze(0) .to(dtype = torch.float32) ) if resample_transform is not None: waveform = resample_transform(waveform) waveform = waveform.unsqueeze(0).to(device) with torch.inference_mode(): codes = snac_model.encode(waveform) # Interleave 7 codes per frame with layer offsets (notebook lines 134-142) all_codes = [] for i in range(codes[0].shape[1]): all_codes.append(codes[0][0][i].item() + AUDIO_OFFSET) all_codes.append(codes[1][0][2 * i].item() + AUDIO_OFFSET + 4096) all_codes.append( codes[2][0][4 * i].item() + AUDIO_OFFSET + (2 * 4096) ) all_codes.append( codes[2][0][(4 * i) + 1].item() + AUDIO_OFFSET + (3 * 4096) ) all_codes.append( codes[1][0][(2 * i) + 1].item() + AUDIO_OFFSET + (4 * 4096) ) all_codes.append( codes[2][0][(4 * i) + 2].item() + AUDIO_OFFSET + (5 * 4096) ) all_codes.append( codes[2][0][(4 * i) + 3].item() + AUDIO_OFFSET + (6 * 4096) ) if len(all_codes) == 0: skipped += 1 continue # Deduplicate consecutive frames with same first code (notebook lines 185-207) deduped = all_codes[:7] for i in range(7, len(all_codes), 7): if all_codes[i] != deduped[-7]: deduped.extend(all_codes[i : i + 7]) all_codes = deduped # --- Build text tokens (notebook lines 217-224) --- text_prompt = ( f"{example[speaker_col]}: {text}" if has_source and example.get(speaker_col) else text ) text_ids = 
tokenizer.encode(text_prompt, add_special_tokens = True) text_ids.append(END_OF_TEXT) # --- Build full input_ids (notebook lines 225-234) --- input_ids = ( [START_OF_HUMAN] + text_ids + [END_OF_HUMAN] + [START_OF_AI] + [START_OF_SPEECH] + all_codes + [END_OF_SPEECH] + [END_OF_AI] ) # Truncate to max_length input_ids = input_ids[:max_length] # Labels = input_ids (no masking — Orpheus trains on full sequence) labels = list(input_ids) attention_mask = [1] * len(input_ids) processed_examples.append( { "input_ids": input_ids, "labels": labels, "attention_mask": attention_mask, } ) except Exception as e: logger.warning(f"Error processing SNAC example {idx}: {e}") skipped += 1 continue # Progress update every 100 examples if (idx + 1) % 100 == 0: self._update_progress( status_message = f"Encoding audio... {idx + 1}/{len(dataset)}" ) # Free SNAC model from GPU logger.info("Freeing SNAC codec model from GPU...\n") snac_model.to("cpu") del snac_model import gc gc.collect() torch.cuda.empty_cache() self._cuda_audio_used = True if not processed_examples: raise ValueError( f"No valid examples after SNAC preprocessing (skipped {skipped})" ) result_dataset = Dataset.from_list(processed_examples) logger.info( f"SNAC preprocessing complete: {len(result_dataset)} examples " f"({skipped} skipped)\n" ) return result_dataset def _preprocess_bicodec_dataset(self, dataset, custom_format_mapping = None): """Preprocess dataset for Spark-TTS training with BiCodec tokenizer. Mirrors Spark_TTS_(0_5B).ipynb: encode audio with BiCodec (semantic + global tokens), format as special-token text strings for SFTTrainer with dataset_text_field="text". """ import sys import torch import numpy as np import torchaudio.transforms as T import subprocess device = "cuda" if torch.cuda.is_available() else "cpu" # The sparktts Python package lives in the SparkAudio/Spark-TTS GitHub repo, # NOT in the unsloth/Spark-TTS-0.5B HF model repo. Clone it if needed. 
spark_code_dir = os.path.join( os.path.dirname(self._spark_tts_repo_dir), "Spark-TTS" ) sparktts_pkg = os.path.join(spark_code_dir, "sparktts") if not os.path.isdir(sparktts_pkg): self._update_progress(status_message = "Cloning Spark-TTS code repo...") logger.info(f"Cloning SparkAudio/Spark-TTS to {spark_code_dir}...\n") subprocess.run( [ "git", "clone", "--depth", "1", "https://github.com/SparkAudio/Spark-TTS", spark_code_dir, ], check = True, ) if spark_code_dir not in sys.path: sys.path.insert(0, spark_code_dir) from sparktts.models.audio_tokenizer import BiCodecTokenizer from sparktts.utils.audio import audio_volume_normalize # Resolve audio and text columns resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] speaker_col = resolved["speaker_col"] has_source = speaker_col is not None if not audio_col or not text_col: raise ValueError( f"BiCodec dataset needs 'audio' and 'text' columns, got: {dataset.column_names}" ) # Cast audio column so datasets 4.x AudioDecoder objects are decoded to dicts. # Don't resample here — BiCodec's target_sr may differ; the loop handles resampling. 
from datasets import Audio dataset = dataset.cast_column(audio_col, Audio()) # Load BiCodec tokenizer self._update_progress(status_message = "Loading BiCodec tokenizer...") logger.info("Loading BiCodec tokenizer...\n") audio_tokenizer = BiCodecTokenizer(self._spark_tts_repo_dir, device) target_sr = audio_tokenizer.config["sample_rate"] self._update_progress(status_message = "Encoding audio with BiCodec...") logger.info( f"BiCodec preprocessing: audio_col='{audio_col}', text_col='{text_col}', " f"has_source={has_source}, target_sr={target_sr}\n" ) def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor: """Extract wav2vec2 features (average of layers 11, 14, 16).""" if wavs.shape[0] != 1: raise ValueError(f"Expected batch size 1, but got shape {wavs.shape}") wav_np = wavs.squeeze(0).cpu().numpy() processed = audio_tokenizer.processor( wav_np, sampling_rate = 16000, return_tensors = "pt", padding = True, ) input_values = processed.input_values.to( audio_tokenizer.feature_extractor.device ) model_output = audio_tokenizer.feature_extractor(input_values) if model_output.hidden_states is None: raise ValueError("Wav2Vec2Model did not return hidden states.") feats_mix = ( model_output.hidden_states[11] + model_output.hidden_states[14] + model_output.hidden_states[16] ) / 3 return feats_mix processed_examples = [] skipped = 0 for idx in range(len(dataset)): if self.should_stop: logger.info("Stopped during BiCodec preprocessing\n") break example = dataset[idx] try: text = example.get(text_col) if not text: skipped += 1 continue audio_data = example.get(audio_col) if audio_data is None or audio_data.get("array") is None: skipped += 1 continue audio_array = audio_data["array"] sampling_rate = audio_data.get("sampling_rate", target_sr) # Resample if needed if sampling_rate != target_sr: resampler = T.Resample(orig_freq = sampling_rate, new_freq = target_sr) audio_tensor_temp = torch.from_numpy(audio_array).float() audio_array = resampler(audio_tensor_temp).numpy() # 
Volume normalize if configured if audio_tokenizer.config.get("volume_normalize", False): audio_array = audio_volume_normalize(audio_array) # Get reference clip ref_wav_np = audio_tokenizer.get_ref_clip(audio_array) # Prepare tensors audio_tensor = ( torch.from_numpy(audio_array).unsqueeze(0).float().to(device) ) ref_wav_tensor = ( torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(device) ) # Extract wav2vec2 features feat = extract_wav2vec2_features(audio_tensor) batch = { "wav": audio_tensor, "ref_wav": ref_wav_tensor, "feat": feat.to(device), } # BiCodec tokenize semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize( batch ) global_tokens = "".join( [ f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze().cpu().numpy() ] ) semantic_tokens = "".join( [ f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze().cpu().numpy() ] ) # Format text with source prefix if available text_content = ( f"{example[speaker_col]}: {text}" if has_source and example.get(speaker_col) else text ) formatted = "".join( [ "<|task_tts|>", "<|start_content|>", text_content, "<|end_content|>", "<|start_global_token|>", global_tokens, "<|end_global_token|>", "<|start_semantic_token|>", semantic_tokens, "<|end_semantic_token|>", "<|im_end|>", ] ) processed_examples.append({"text": formatted}) except Exception as e: logger.warning(f"Error processing BiCodec example {idx}: {e}") skipped += 1 continue # Progress update every 100 examples if (idx + 1) % 100 == 0: self._update_progress( status_message = f"Encoding audio with BiCodec... 
{idx + 1}/{len(dataset)}" ) # Free BiCodec model from GPU logger.info("Freeing BiCodec tokenizer from GPU...\n") audio_tokenizer.model.cpu() audio_tokenizer.feature_extractor.cpu() del audio_tokenizer import gc gc.collect() torch.cuda.empty_cache() self._cuda_audio_used = True if not processed_examples: raise ValueError( f"No valid examples after BiCodec preprocessing (skipped {skipped})" ) result_dataset = Dataset.from_list(processed_examples) logger.info( f"BiCodec preprocessing complete: {len(result_dataset)} examples " f"({skipped} skipped)\n" ) # Debug: show first example text (truncated) sample = result_dataset[0]["text"] logger.info(f"Sample text (first 200 chars): {sample[:200]}...\n") logger.info(f"Sample text length: {len(sample)} chars\n") return result_dataset def _preprocess_dac_dataset(self, dataset, custom_format_mapping = None): """Preprocess dataset for OuteTTS training with DAC codec. Mirrors Oute_TTS_(1B).ipynb DataCreationV3: uses Whisper for word timings, OuteTTS AudioProcessor for speaker representations, PromptProcessor for training prompts. Outputs text strings for SFTTrainer with dataset_text_field="text". 
""" import sys import io import tempfile import torch import numpy as np import soundfile as sf from datasets import Dataset as HFDataset from utils.paths import ensure_dir, tmp_root device = "cuda" if torch.cuda.is_available() else "cpu" # Clone OuteTTS repo (same as audio_codecs._load_dac) import subprocess base_dir = os.path.dirname(os.path.abspath(__file__)) outetts_code_dir = os.path.join(base_dir, "inference", "OuteTTS") outetts_pkg = os.path.join(outetts_code_dir, "outetts") if not os.path.isdir(outetts_pkg): self._update_progress(status_message = "Cloning OuteTTS code repo...") logger.info(f"Cloning edwko/OuteTTS to {outetts_code_dir}...\n") subprocess.run( [ "git", "clone", "--depth", "1", "https://github.com/edwko/OuteTTS", outetts_code_dir, ], check = True, ) for fpath in [ os.path.join(outetts_pkg, "models", "gguf_model.py"), os.path.join(outetts_pkg, "interface.py"), os.path.join(outetts_pkg, "__init__.py"), ]: if os.path.exists(fpath): os.remove(fpath) logger.info(f"Removed {fpath}\n") if outetts_code_dir not in sys.path: sys.path.insert(0, outetts_code_dir) from outetts.version.v3.audio_processor import AudioProcessor from outetts.version.v3.prompt_processor import PromptProcessor from outetts.models.config import ModelConfig as OuteTTSModelConfig from outetts.utils.preprocessing import text_normalizations # Resolve audio and text columns resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] if not audio_col or not text_col: raise ValueError( f"DAC dataset needs 'audio' and 'text' columns, got: {dataset.column_names}" ) # Cast audio to 24kHz (notebook: dataset.cast_column("audio", Audio(sampling_rate=24000))) from datasets import Audio dataset = dataset.cast_column(audio_col, Audio(sampling_rate = 24000)) logger.info("Cast audio column to 24kHz\n") # Load Whisper for word timings self._update_progress( status_message = "Loading Whisper model for word timings..." 
) logger.info("Loading Whisper model for word timings...\n") import whisper whisper_model = whisper.load_model("turbo", device = device) # Load OuteTTS AudioProcessor + PromptProcessor self._update_progress(status_message = "Loading OuteTTS AudioProcessor...") logger.info("Loading OuteTTS AudioProcessor...\n") model_tokenizer_path = "OuteAI/Llama-OuteTTS-1.0-1B" dummy_config = OuteTTSModelConfig( tokenizer_path = model_tokenizer_path, device = device, audio_codec_path = None, ) audio_processor = AudioProcessor(config = dummy_config) prompt_processor = PromptProcessor(model_tokenizer_path) self._update_progress(status_message = "Preprocessing audio with OuteTTS...") logger.info( f"DAC preprocessing: audio_col='{audio_col}', text_col='{text_col}'\n" ) processed_examples = [] skipped = 0 for idx in range(len(dataset)): if self.should_stop: logger.info("Stopped during DAC preprocessing\n") break example = dataset[idx] try: text = example.get(text_col) if not text or not isinstance(text, str): skipped += 1 continue audio_data = example.get(audio_col) if audio_data is None or audio_data.get("array") is None: skipped += 1 continue audio_array = np.array(audio_data["array"], dtype = np.float32) sampling_rate = audio_data.get("sampling_rate", 24000) # Convert to WAV bytes (Whisper needs a file path) buf = io.BytesIO() sf.write(buf, audio_array, sampling_rate, format = "WAV", subtype = "FLOAT") buf.seek(0) audio_bytes = buf.getvalue() # 1. 
Get word timings from Whisper with tempfile.NamedTemporaryFile( suffix = ".wav", delete = False, dir = str(ensure_dir(tmp_root())), ) as tmp: tmp.write(audio_bytes) tmp.flush() tmp_path = tmp.name try: whisper_result = whisper_model.transcribe( tmp_path, word_timestamps = True ) finally: Path(tmp_path).unlink(missing_ok = True) normalized_transcript = text_normalizations(text) words_with_timings = [] if whisper_result and "segments" in whisper_result: for segment in whisper_result["segments"]: for word_info in segment.get("words", []): cleaned = word_info["word"].strip() if cleaned: words_with_timings.append( { "word": cleaned, "start": float(word_info["start"]), "end": float(word_info["end"]), } ) if not words_with_timings: skipped += 1 continue # 2. Create speaker representation with AudioProcessor speaker_data_dict = { "audio": {"bytes": audio_bytes}, "text": normalized_transcript, "words": words_with_timings, } speaker = audio_processor.create_speaker_from_dict(speaker_data_dict) if speaker is None: skipped += 1 continue # 3. Get training prompt from PromptProcessor prompt = prompt_processor.get_training_prompt(speaker) if prompt: processed_examples.append({"text": prompt}) except Exception as e: logger.warning(f"Error processing DAC example {idx}: {e}") skipped += 1 continue if (idx + 1) % 100 == 0: self._update_progress( status_message = f"Preprocessing audio with OuteTTS... 
{idx + 1}/{len(dataset)}" ) # Free Whisper from GPU (notebook: data_processor.whisper_model.to('cpu')) logger.info("Moving Whisper model to CPU...\n") whisper_model.to("cpu") del whisper_model del audio_processor del prompt_processor import gc gc.collect() torch.cuda.empty_cache() self._cuda_audio_used = True if not processed_examples: raise ValueError( f"No valid examples after DAC preprocessing (skipped {skipped})" ) result_dataset = HFDataset.from_list(processed_examples) logger.info( f"DAC preprocessing complete: {len(result_dataset)} examples " f"({skipped} skipped)\n" ) sample = result_dataset[0]["text"] logger.info(f"Sample text (first 200 chars): {sample[:200]}...\n") return result_dataset def _preprocess_whisper_dataset( self, dataset, eval_split = None, custom_format_mapping = None ): """Preprocess dataset for Whisper speech-to-text training. Mirrors Whisper.ipynb: extract audio features with Whisper's feature extractor, tokenize text labels. Returns (train_data, eval_data) where each is a list of dicts with 'input_features' and 'labels'. 
""" from datasets import Audio WHISPER_SAMPLE_RATE = 16000 resolved = self._resolve_audio_columns(dataset, custom_format_mapping) audio_col = resolved["audio_col"] text_col = resolved["text_col"] if not audio_col or not text_col: raise ValueError( f"Whisper dataset needs 'audio' and 'text' columns, got: {dataset.column_names}" ) # Cast audio to 16kHz (Whisper's expected sample rate) dataset = dataset.cast_column( audio_col, Audio(sampling_rate = WHISPER_SAMPLE_RATE) ) # Train/eval split (notebook does dataset.train_test_split) eval_dataset_raw = None if eval_split: splits = dataset.train_test_split(test_size = 0.06, seed = 42) dataset = splits["train"] eval_dataset_raw = splits["test"] self._update_progress(status_message = "Processing audio for Whisper...") logger.info( f"Whisper preprocessing: audio_col='{audio_col}', text_col='{text_col}', " f"samples={len(dataset)}\n" ) def process_split(ds, split_name = "train"): processed = [] skipped = 0 for idx in range(len(ds)): if self.should_stop: logger.info(f"Stopped during Whisper {split_name} preprocessing\n") break example = ds[idx] try: audio_data = example.get(audio_col) text = example.get(text_col) if ( audio_data is None or audio_data.get("array") is None or not text ): skipped += 1 continue # Extract audio features (notebook line 112-115) features = self.tokenizer.feature_extractor( audio_data["array"], sampling_rate = audio_data["sampling_rate"] ) # Tokenize text (notebook line 116) tokenized_text = self.tokenizer.tokenizer(text) processed.append( { "input_features": features.input_features[0], "labels": tokenized_text.input_ids, } ) except Exception as e: logger.warning( f"Error processing Whisper {split_name} example {idx}: {e}" ) skipped += 1 continue if (idx + 1) % 100 == 0: self._update_progress( status_message = f"Processing {split_name} audio... 
{idx + 1}/{len(ds)}" ) logger.info( f"Whisper {split_name} preprocessing: {len(processed)} examples ({skipped} skipped)\n" ) return processed train_data = process_split(dataset, "train") eval_data = ( process_split(eval_dataset_raw, "eval") if eval_dataset_raw else None ) if not train_data: raise ValueError("No valid examples after Whisper preprocessing") return (train_data, eval_data) @staticmethod def _resolve_local_files(file_paths: list) -> list[str]: """Resolve a list of local dataset paths to concrete file paths.""" all_files: list[str] = [] for dataset_file in file_paths: if os.path.isabs(dataset_file): file_path = dataset_file else: file_path = str(resolve_dataset_path(dataset_file)) file_path_obj = Path(file_path) if file_path_obj.is_dir(): parquet_dir = ( file_path_obj / "parquet-files" if (file_path_obj / "parquet-files").exists() else file_path_obj ) parquet_files = sorted(parquet_dir.glob("*.parquet")) if parquet_files: all_files.extend(str(p) for p in parquet_files) continue candidates: list[Path] = [] for ext in (".json", ".jsonl", ".csv", ".parquet"): candidates.extend(sorted(file_path_obj.glob(f"*{ext}"))) if candidates: all_files.extend(str(c) for c in candidates) continue raise ValueError( f"No supported data files in directory: {file_path_obj}" ) else: all_files.append(str(file_path_obj)) return all_files @staticmethod def _loader_for_files(files: list[str]) -> str: """Determine the HF datasets loader type from file extensions.""" first_ext = Path(files[0]).suffix.lower() if first_ext in (".json", ".jsonl"): return "json" elif first_ext == ".csv": return "csv" elif first_ext == ".parquet": return "parquet" raise ValueError(f"Unsupported dataset format: {files[0]}") def load_and_format_dataset( self, dataset_source: str, format_type: str = "auto", local_datasets: list = None, local_eval_datasets: list = None, custom_format_mapping: dict = None, subset: str = None, train_split: str = "train", eval_split: str = None, eval_steps: float = 0.00, 
dataset_slice_start: int = None, dataset_slice_end: int = None, ) -> Optional[tuple]: """ Load and prepare dataset for training. Strategy: format first, then split — ensures both train and eval portions are properly formatted and templated. Returns: Tuple of (dataset_info, eval_dataset) or None on error. eval_dataset may be None if no eval split is available. """ try: dataset = None eval_dataset = None has_separate_eval_source = ( False # True if eval comes from a separate HF split ) eval_enabled = eval_steps is not None and eval_steps > 0 if local_datasets: # Load local datasets using load_dataset() so the result is # Arrow-backed (has cache files). Dataset.from_list() creates # an in-memory dataset with no cache, which forces num_proc=1 # during tokenization/map because sharding requires Arrow files. all_files = self._resolve_local_files(local_datasets) if all_files: loader = self._loader_for_files(all_files) dataset = load_dataset(loader, data_files = all_files, split = "train") # Check if stopped during dataset loading if self.should_stop: logger.info("Stopped during dataset loading\n") return None self._update_progress( status_message = f"Loaded {len(dataset)} samples from local files" ) logger.info(f"Loaded {len(dataset)} samples from local files\n") logger.info(f"[DEBUG] Dataset cache_files: {dataset.cache_files}\n") # Load local eval datasets if provided if local_eval_datasets and eval_enabled: eval_all_files = self._resolve_local_files(local_eval_datasets) if eval_all_files: eval_loader = self._loader_for_files(eval_all_files) eval_dataset = load_dataset( eval_loader, data_files = eval_all_files, split = "train" ) has_separate_eval_source = True logger.info( f"Loaded {len(eval_dataset)} eval samples from local eval files\n" ) elif dataset_source: # Load from Hugging Face split_name = train_split or "train" load_kwargs = {"path": dataset_source, "split": split_name} if subset: load_kwargs["name"] = subset _slice_start = dataset_slice_start or 0 if ( 
dataset_slice_end is not None and dataset_slice_end >= 0 and dataset_slice_end >= _slice_start ): # Manual slice — stream only the rows we need instead of # downloading the entire dataset. rows_to_stream = dataset_slice_end + 1 logger.info( f"[dataset-slice] Manual slice specified " f"(start={dataset_slice_start}, end={dataset_slice_end}), " f"streaming {rows_to_stream} rows\n" ) stream = load_dataset(**load_kwargs, streaming = True) dataset = Dataset.from_list(list(stream.take(rows_to_stream))) logger.info( f"[dataset-slice] Downloaded {len(dataset)} rows " f"(requested {rows_to_stream})\n" ) self._update_progress( status_message = f"Streamed {len(dataset)} rows from HuggingFace" ) else: self._update_progress( status_message = f"Downloading dataset: {dataset_source}..." ) dataset = load_dataset(**load_kwargs) # Check if stopped during dataset loading if self.should_stop: logger.info("Stopped during dataset loading\n") return None n_rows = len(dataset) if hasattr(dataset, "__len__") else 0 self._update_progress( status_message = f"Downloaded {dataset_source} ({n_rows:,} rows)" ) logger.info( f"Loaded dataset from Hugging Face: {dataset_source} ({n_rows:,} rows)\n" ) # Resolve eval split from a separate HF split (explicit or auto-detected) if eval_enabled: effective_train = train_split or "train" if eval_split and eval_split != effective_train: # Explicit eval split provided - load it directly logger.info(f"Loading explicit eval split: '{eval_split}'\n") eval_load_kwargs = {"path": dataset_source, "split": eval_split} if subset: eval_load_kwargs["name"] = subset eval_dataset = load_dataset(**eval_load_kwargs) has_separate_eval_source = True logger.info( f"Loaded eval split '{eval_split}' with {len(eval_dataset)} rows\n" ) elif eval_split and eval_split == effective_train: # Same split as training — will do 80/20 split after formatting logger.info( f"Eval split '{eval_split}' is the same as train split — will split 80/20\n" ) else: # Auto-detect eval split from HF 
(returns a separate dataset, or None) eval_dataset = self._auto_detect_eval_split_from_hf( dataset_source = dataset_source, subset = subset, ) if eval_dataset is not None: has_separate_eval_source = True else: logger.info( "Eval disabled (eval_steps <= 0), skipping eval split detection\n" ) if dataset is None: raise ValueError("No dataset provided") # Apply index range slicing if requested (inclusive on both ends) if dataset_slice_start is not None or dataset_slice_end is not None: total_rows = len(dataset) start = dataset_slice_start if dataset_slice_start is not None else 0 end = ( dataset_slice_end if dataset_slice_end is not None else total_rows - 1 ) # Clamp to valid range start = max(0, min(start, total_rows - 1)) end = max(start, min(end, total_rows - 1)) dataset = dataset.select(range(start, end + 1)) logger.info( f"Sliced dataset to rows [{start}, {end}]: {len(dataset)} of {total_rows} rows\n" ) self._update_progress( status_message = f"Sliced dataset to {len(dataset)} rows (indices {start}-{end})" ) # Check if stopped before applying template if self.should_stop: logger.info("Stopped before applying chat template\n") return None # ========== AUDIO MODELS: custom preprocessing ========== if self._audio_type == "csm": processed = self._preprocess_csm_dataset(dataset, custom_format_mapping) return (processed, None) elif self._audio_type == "whisper": train_data, eval_data = self._preprocess_whisper_dataset( dataset, eval_split = eval_split, custom_format_mapping = custom_format_mapping, ) return (train_data, eval_data) elif self._audio_type == "snac": processed = self._preprocess_snac_dataset( dataset, custom_format_mapping ) return (processed, None) elif self._audio_type == "bicodec": processed = self._preprocess_bicodec_dataset( dataset, custom_format_mapping ) return ({"dataset": processed, "final_format": "audio_bicodec"}, None) elif self._audio_type == "dac": processed = self._preprocess_dac_dataset(dataset, custom_format_mapping) return ({"dataset": 
processed, "final_format": "audio_dac"}, None) elif self.is_audio_vlm: formatted = self._format_audio_vlm_dataset( dataset, custom_format_mapping ) return (formatted, None) # ========== FORMAT FIRST ========== logger.info(f"Formatting dataset with format_type='{format_type}'...\n") dataset_info = format_and_template_dataset( dataset, model_name = self.model_name, tokenizer = self.tokenizer, is_vlm = self.is_vlm, format_type = format_type, dataset_name = dataset_source, custom_format_mapping = custom_format_mapping, progress_callback = self._update_progress, ) # Check if stopped during formatting if self.should_stop: logger.info("Stopped during dataset formatting\n") return None # Abort if dataset formatting/conversion failed if not dataset_info.get("success", True): errors = dataset_info.get("errors", []) error_msg = "; ".join(errors) if errors else "Dataset formatting failed" logger.error(f"Dataset conversion failed: {error_msg}") self._update_progress(error = error_msg) return None detected = dataset_info.get("detected_format", "unknown") final_ds = dataset_info.get("dataset") final_n = len(final_ds) if hasattr(final_ds, "__len__") else "?" 
self._update_progress( status_message = f"Dataset ready ({final_n:,} samples, {detected} format)" ) logger.info( f"Dataset formatted successfully ({final_n} samples, {detected})\n" ) # ========== THEN SPLIT ========== if has_separate_eval_source and eval_dataset is not None: # Eval came from a separate HF split — format it too logger.info(f"Formatting eval dataset ({len(eval_dataset)} rows)...\n") eval_info = format_and_template_dataset( eval_dataset, model_name = self.model_name, tokenizer = self.tokenizer, is_vlm = self.is_vlm, format_type = format_type, dataset_name = dataset_source, custom_format_mapping = custom_format_mapping, ) eval_dataset = eval_info["dataset"] logger.info(f"Eval dataset formatted successfully\n") elif eval_enabled and not has_separate_eval_source: # No separate eval source — split the already-formatted dataset formatted_dataset = dataset_info["dataset"] split_result = self._resolve_eval_split_from_dataset(formatted_dataset) if split_result is not None: train_portion, eval_dataset = split_result dataset_info["dataset"] = train_portion return (dataset_info, eval_dataset) except Exception as e: logger.error(f"Error loading dataset: {e}") self._update_progress(error = str(e)) return None def _auto_detect_eval_split_from_hf( self, dataset_source: str, subset: str ) -> Optional[Dataset]: """Auto-detect an eval split from HF dataset (separate named split only).""" try: from datasets import get_dataset_split_names load_kwargs = {"path": dataset_source} if subset: load_kwargs["config_name"] = subset available_splits = get_dataset_split_names(**load_kwargs) logger.info(f"Available splits: {available_splits}\n") # Check for common eval split names for candidate in ["eval", "validation", "valid", "val", "test"]: if candidate in available_splits: eval_load_kwargs = {"path": dataset_source, "split": candidate} if subset: eval_load_kwargs["name"] = subset candidate_ds = load_dataset(**eval_load_kwargs) if len(candidate_ds) >= 16: logger.info( 
f"Auto-detected eval split '{candidate}' with {len(candidate_ds)} rows\n" ) return candidate_ds else: logger.info( f"Found eval split '{candidate}' but only {len(candidate_ds)} rows (< 16), skipping\n" ) except Exception as e: logger.warning(f"Could not check dataset splits: {e}") # No separate HF eval split found — caller will handle programmatic splitting return None def _resolve_eval_split_from_dataset(self, dataset) -> Optional[tuple]: """Split a dataset into train and eval portions. Returns: Tuple of (train_dataset, eval_dataset), or None if dataset too small. """ MIN_EVAL_ROWS = 16 MIN_TOTAL_ROWS = 32 # Need at least 16 train + 16 eval n = len(dataset) if n < MIN_TOTAL_ROWS: logger.info(f"Dataset too small ({n} rows) for eval split, skipping eval\n") return None eval_size = max(MIN_EVAL_ROWS, min(128, int(0.05 * n))) # Ensure we don't take more than half the dataset eval_size = min(eval_size, n // 2) logger.info(f"Auto-splitting: {eval_size} rows for eval from {n} total\n") split_result = dataset.train_test_split(test_size = eval_size, seed = 3407) logger.info( f"Split complete: {len(split_result['train'])} train, {len(split_result['test'])} eval\n" ) return (split_result["train"], split_result["test"]) def start_training( self, dataset: Dataset, eval_dataset: Dataset = None, eval_steps: float = 0.00, output_dir: str | None = None, num_epochs: int = 3, learning_rate: float = 5e-5, batch_size: int = 2, gradient_accumulation_steps: int = 4, warmup_steps: int = None, warmup_ratio: float = None, max_steps: int = 0, save_steps: int = 0, weight_decay: float = 0.01, random_seed: int = 3407, packing: bool = False, train_on_completions: bool = False, enable_wandb: bool = False, wandb_project: str = "unsloth-training", wandb_token: str = None, enable_tensorboard: bool = False, tensorboard_dir: str | None = None, **kwargs, ) -> bool: """Start training in a separate thread""" if self.is_training: logger.warning("Training already in progress") return False if self.model 
is None or self.tokenizer is None: self._update_progress(error = "Model not loaded") return False # Pre-import heavy transformers modules on the main thread. # Unsloth's patched_import hook (deepseek_v3_moe.py) is not thread-safe # with Python's importlib cache, causing KeyError: 'size' if these are # first imported inside the worker thread. import transformers # noqa: F401 – ensures submodules are cached from transformers import ( # noqa: F401 Trainer as _HFTrainer, TrainingArguments as _TrainingArguments, TrainerCallback as _TrainerCallback, ) if self._audio_type == "whisper": from transformers import ( # noqa: F401 Seq2SeqTrainer as _Seq2SeqTrainer, Seq2SeqTrainingArguments as _Seq2SeqTrainingArguments, ) # Start training in separate thread self.training_thread = threading.Thread( target = self._train_worker, args = (dataset,), kwargs = { "output_dir": output_dir, "num_epochs": num_epochs, "learning_rate": learning_rate, "batch_size": batch_size, "gradient_accumulation_steps": gradient_accumulation_steps, "warmup_steps": warmup_steps, "warmup_ratio": warmup_ratio, "max_steps": max_steps, "save_steps": save_steps, "weight_decay": weight_decay, "random_seed": random_seed, "packing": packing, "train_on_completions": train_on_completions, "enable_wandb": enable_wandb, "wandb_project": wandb_project, "wandb_token": wandb_token, "enable_tensorboard": enable_tensorboard, "tensorboard_dir": tensorboard_dir, "eval_dataset": eval_dataset, "eval_steps": eval_steps, **kwargs, }, ) self.should_stop = False self.is_training = True try: self.training_thread.start() return True except Exception as e: self.is_training = False logger.error(f"Failed to start training thread: {e}") return False def _train_worker(self, dataset: Dataset, **training_args): """Worker function for training (runs in separate thread)""" try: # Store training parameters for metrics calculation self.batch_size = training_args.get("batch_size", 2) self.max_seq_length = training_args.get("max_seq_length", 
2048) self.gradient_accumulation_steps = training_args.get( "gradient_accumulation_steps", 4 ) # Set training start time self.training_start_time = time.time() self._update_progress(is_training = True, error = None) # Setup logging if training_args.get("enable_wandb", False) and training_args.get( "wandb_token" ): os.environ["WANDB_API_KEY"] = training_args["wandb_token"] import wandb wandb.init( project = training_args.get("wandb_project", "unsloth-training") ) # Create output directory output_dir = str(resolve_output_dir(training_args.get("output_dir"))) ensure_dir(Path(output_dir)) # ========== AUDIO TRAINER BRANCH ========== if self._audio_type == "csm": # CSM uses plain HF Trainer (NOT SFTTrainer) # Needs remove_unused_columns=False for depth decoder (input_values + cutoffs) from transformers import Trainer as HFTrainer, TrainingArguments self._apply_csm_forward_fix() config = self._build_audio_training_args( training_args, output_dir, extra_args = { "remove_unused_columns": False, }, ) self.trainer = HFTrainer( model = self.model, train_dataset = dataset, args = TrainingArguments(**config), ) self.trainer.add_callback(self._create_progress_callback()) batch_size = training_args.get("batch_size", 2) total = self._calculate_total_steps( len(dataset), batch_size, training_args.get("gradient_accumulation_steps", 4), training_args.get("num_epochs", 3), training_args.get("max_steps", 0), ) self._update_progress( total_steps = total, status_message = "Starting CSM training..." ) logger.info(f"CSM training config: {config}\n") self.trainer.train() self._finalize_training(output_dir, "CSM") return elif self._audio_type == "snac": # Orpheus: language model with SNAC codec tokens — plain HF Trainer # DataCollatorForSeq2Seq dynamically pads variable-length sequences per batch # (text + audio codes vary in length) and pads labels with -100. 
from transformers import ( Trainer as HFTrainer, TrainingArguments, DataCollatorForSeq2Seq, ) config = self._build_audio_training_args(training_args, output_dir) self.trainer = HFTrainer( model = self.model, train_dataset = dataset, args = TrainingArguments(**config), data_collator = DataCollatorForSeq2Seq( tokenizer = self.tokenizer, padding = True, pad_to_multiple_of = 8, ), ) self.trainer.add_callback(self._create_progress_callback()) batch_size = training_args.get("batch_size", 2) total = self._calculate_total_steps( len(dataset), batch_size, training_args.get("gradient_accumulation_steps", 4), training_args.get("num_epochs", 3), training_args.get("max_steps", 0), ) self._update_progress( total_steps = total, status_message = "Starting SNAC training..." ) logger.info(f"SNAC training config: {config}\n") self.trainer.train() self._finalize_training(output_dir, "SNAC") return elif self._audio_type == "whisper": # Whisper: Seq2SeqTrainer with custom speech collator from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments from utils.datasets import DataCollatorSpeechSeq2SeqWithPadding eval_dataset = training_args.get("eval_dataset", None) extra = {"remove_unused_columns": False, "label_names": ["labels"]} if eval_dataset: extra["eval_strategy"] = "steps" extra["eval_steps"] = training_args.get("eval_steps", 5) config = self._build_audio_training_args( training_args, output_dir, extra_args = extra ) trainer_kwargs = { "model": self.model, "train_dataset": dataset, "data_collator": DataCollatorSpeechSeq2SeqWithPadding( processor = self.tokenizer ), "processing_class": self.tokenizer.feature_extractor, "args": Seq2SeqTrainingArguments(**config), } if eval_dataset: trainer_kwargs["eval_dataset"] = eval_dataset self.trainer = Seq2SeqTrainer(**trainer_kwargs) self.trainer.add_callback(self._create_progress_callback()) batch_size = training_args.get("batch_size", 2) total = self._calculate_total_steps( len(dataset), batch_size, 
training_args.get("gradient_accumulation_steps", 4), training_args.get("num_epochs", 3), training_args.get("max_steps", 0), ) self._update_progress( total_steps = total, status_message = "Starting Whisper training..." ) logger.info(f"Whisper training config: {config}\n") self.trainer.train() self._finalize_training(output_dir, "Whisper") return elif self._audio_type is not None and self._audio_type not in ( "bicodec", "dac", ): # bicodec/dac use the standard SFTTrainer text path below raise NotImplementedError( f"Audio training for '{self._audio_type}' not yet implemented" ) # ========== DATA COLLATOR SELECTION ========== # Detect special model types model_name_lower = self.model_name.lower() is_deepseek_ocr = ( "deepseek" in model_name_lower and "ocr" in model_name_lower ) logger.info("Configuring data collator...\n") data_collator = None # Default to built-in data collator if is_deepseek_ocr: # Special DeepSeek OCR collator - auto-install if needed logger.info("Detected DeepSeek OCR model\n") # Ensure DeepSeek OCR module is installed if not _ensure_deepseek_ocr_installed(): error_msg = ( "Failed to install DeepSeek OCR module. 
" "Please install manually: " "from huggingface_hub import snapshot_download; " "snapshot_download('unsloth/DeepSeek-OCR', local_dir='deepseek_ocr')" ) logger.error(error_msg) self._update_progress(error = error_msg, is_training = False) return try: from backend.data_utils import DeepSeekOCRDataCollator logger.info("Configuring DeepSeek OCR data collator...\n") FastVisionModel.for_training(self.model) data_collator = DeepSeekOCRDataCollator( tokenizer = self.tokenizer, model = self.model, image_size = 640, base_size = 1024, crop_mode = True, train_on_responses_only = training_args.get( "train_on_completions", False ), ) logger.info("DeepSeek OCR data collator configured successfully\n") except Exception as e: logger.error(f"Failed to configure DeepSeek OCR collator: {e}") error_msg = f"Error configuring DeepSeek OCR: {str(e)}" self._update_progress(error = error_msg, is_training = False) return elif self.is_audio_vlm: # Audio VLM collator (e.g. Gemma 3N with audio data) # Mirrors the collate_fn from Gemma3N_(4B)-Audio notebook logger.info("Configuring audio VLM data collator...\n") processor = self.tokenizer # FastModel returns processor as tokenizer audio_col_name = getattr(self, "_audio_vlm_audio_col", "audio") def audio_vlm_collate_fn(examples): texts = [] audios = [] for example in examples: text = processor.apply_chat_template( example["messages"], tokenize = False, add_generation_prompt = False, ).strip() texts.append(text) audios.append(example[audio_col_name]["array"]) batch = processor( text = texts, audio = audios, return_tensors = "pt", padding = True ) # Labels = input_ids with special tokens masked labels = batch["input_ids"].clone() labels[labels == processor.tokenizer.pad_token_id] = -100 for attr in ( "audio_token_id", "image_token_id", "boi_token_id", "eoi_token_id", ): token_id = getattr(processor.tokenizer, attr, None) if token_id is not None: labels[labels == token_id] = -100 batch["labels"] = labels return batch data_collator = 
audio_vlm_collate_fn logger.info("Audio VLM data collator configured\n") elif self.is_vlm: # Standard VLM collator (images) logger.info("Using UnslothVisionDataCollator for vision model\n") from unsloth.trainer import UnslothVisionDataCollator FastVisionModel.for_training(self.model) data_collator = UnslothVisionDataCollator(self.model, self.tokenizer) logger.info("Vision data collator configured\n") # ========== TRAINING CONFIGURATION ========== # Handle warmup_steps vs warmup_ratio warmup_steps_val = training_args.get("warmup_steps", None) warmup_ratio_val = training_args.get("warmup_ratio", None) lr_value = training_args.get("learning_rate", 2e-4) logger.info( f"[DEBUG] learning_rate from training_args: {lr_value} (type: {type(lr_value).__name__})\n" ) config_args = { "per_device_train_batch_size": training_args.get("batch_size", 2), "gradient_accumulation_steps": training_args.get( "gradient_accumulation_steps", 4 ), "num_train_epochs": training_args.get( "num_epochs", 3 ), # Default to epochs "learning_rate": lr_value, "fp16": not is_bfloat16_supported(), "bf16": is_bfloat16_supported(), "logging_steps": 1, "weight_decay": training_args.get("weight_decay", 0.01), "seed": training_args.get("random_seed", 3407), "output_dir": output_dir, "report_to": _build_report_targets(training_args), "include_num_input_tokens_seen": True, # Enable token counting "dataset_num_proc": 1 if (self.is_audio or self.is_audio_vlm or self._cuda_audio_used) else safe_num_proc(max(1, os.cpu_count() // 4)), "max_seq_length": training_args.get("max_seq_length", 2048), } if training_args.get("enable_tensorboard", False): config_args["logging_dir"] = str( resolve_tensorboard_dir(training_args.get("tensorboard_dir")) ) logger.info( f"[DEBUG] dataset_num_proc={config_args['dataset_num_proc']} (is_audio={self.is_audio}, is_audio_vlm={self.is_audio_vlm}, _cuda_audio_used={self._cuda_audio_used})" ) # On Windows with transformers 5.x, disable DataLoader multiprocessing # to avoid issues with 
modified sys.path (.venv_t5) in spawned workers. if sys.platform == "win32": import transformers as _tf if _tf.__version__.startswith("5."): config_args["dataloader_num_workers"] = 0 # Add warmup parameter - use warmup_ratio if provided, otherwise warmup_steps if warmup_ratio_val is not None: config_args["warmup_ratio"] = warmup_ratio_val logger.info(f"Using warmup_ratio: {warmup_ratio_val}\n") elif warmup_steps_val is not None: config_args["warmup_steps"] = warmup_steps_val logger.info(f"Using warmup_steps: {warmup_steps_val}\n") else: # Default to warmup_steps if neither provided config_args["warmup_steps"] = 5 logger.info(f"Using default warmup_steps: 5\n") # Add save_steps if specified save_steps_val = training_args.get("save_steps", 0) if save_steps_val and save_steps_val > 0: config_args["save_steps"] = save_steps_val config_args["save_strategy"] = "steps" # If max_steps is specified, use it instead of epochs max_steps_val = training_args.get("max_steps", 0) if max_steps_val and max_steps_val > 0: del config_args["num_train_epochs"] # Remove epochs config_args["max_steps"] = max_steps_val # Use steps instead logger.info(f"Training for {max_steps_val} steps\n") else: logger.info(f"Training for {config_args['num_train_epochs']} epochs\n") # ========== EVAL CONFIGURATION ========== eval_dataset = training_args.get("eval_dataset", None) eval_steps_val = training_args.get("eval_steps", 0.00) if eval_dataset is not None: if eval_steps_val > 0: config_args["eval_strategy"] = "steps" config_args["eval_steps"] = eval_steps_val logger.info( f"✅ Evaluation enabled: eval_steps={eval_steps_val} (fraction of total steps)\n" ) logger.info(f"Eval dataset: {len(eval_dataset)} rows\n") else: logger.info( f"⚠️ Eval dataset provided but eval_steps={eval_steps_val} (disabled)\n" ) logger.info("To enable evaluation, set eval_steps > 0.0\n") else: logger.info("No eval dataset — evaluation disabled\n") # Add model-specific parameters # Use optim and lr_scheduler_type from 
training_args if provided, otherwise use defaults optim_value = training_args.get("optim", "adamw_8bit") lr_scheduler_type_value = training_args.get("lr_scheduler_type", "linear") if self.is_vlm or self.is_audio_vlm: # Vision / audio VLM config (both need skip_prepare_dataset + remove_unused_columns) label = "audio VLM" if self.is_audio_vlm else "vision" logger.info(f"Configuring {label} model training parameters\n") # Use provided values or defaults for vision models optim_value = training_args.get("optim", "adamw_torch_fused") lr_scheduler_type_value = training_args.get( "lr_scheduler_type", "cosine" ) config_args.update( { "optim": optim_value, "lr_scheduler_type": lr_scheduler_type_value, "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": False}, "max_grad_norm": 0.3, "remove_unused_columns": False, "dataset_text_field": "", "dataset_kwargs": {"skip_prepare_dataset": True}, "max_length": training_args.get("max_seq_length", 2048), } ) else: logger.info("Configuring text model training parameters\n") config_args.update( { "optim": optim_value, "lr_scheduler_type": lr_scheduler_type_value, "dataset_text_field": "text", } ) # Only add packing for text models (not DeepSeek OCR which is VLM) if not is_deepseek_ocr: packing_enabled = training_args.get("packing", False) config_args["packing"] = packing_enabled logger.info( f"Sequence packing: {'enabled' if packing_enabled else 'disabled'}\n" ) # Audio codec overrides — BiCodec/DAC use the text SFTTrainer path if self._audio_type == "bicodec": config_args["packing"] = False logger.info("Applied BiCodec overrides: packing=False\n") elif self._audio_type == "dac": config_args["packing"] = False logger.info("Applied DAC overrides: packing=False\n") logger.info(f"The configuration is: {config_args}") logger.info("Training configuration prepared\n") # ========== TRAINER INITIALIZATION ========== if self.is_audio_vlm: # Audio VLM (e.g. 
Gemma 3N + audio): raw Dataset from _format_audio_vlm_dataset # Notebook uses processing_class=processor.tokenizer (text tokenizer only) train_dataset = ( dataset if isinstance(dataset, Dataset) else dataset["dataset"] ) processing_class = ( self.tokenizer.tokenizer if hasattr(self.tokenizer, "tokenizer") else self.tokenizer ) trainer_kwargs = { "model": self.model, "train_dataset": train_dataset, "processing_class": processing_class, "data_collator": data_collator, "args": SFTConfig(**config_args), } if eval_dataset is not None: trainer_kwargs["eval_dataset"] = eval_dataset self.trainer = SFTTrainer(**trainer_kwargs) elif self.is_vlm: # Image VLM: dataset is dict wrapper from format_and_template_dataset train_dataset = ( dataset["dataset"] if isinstance(dataset, dict) else dataset ) trainer_kwargs = { "model": self.model, "train_dataset": train_dataset, "processing_class": self.tokenizer, "data_collator": data_collator, "args": SFTConfig(**config_args), } if eval_dataset is not None: trainer_kwargs["eval_dataset"] = eval_dataset self.trainer = SFTTrainer(**trainer_kwargs) else: # For text-only training, if the tokenizer is actually a Processor # (e.g., Gemma-3 returns a ProcessorMixin even for text), we must # unwrap to the raw tokenizer. Otherwise Unsloth's SFTTrainer detects # ProcessorMixin → sets _is_vlm=True → skips _prepare_dataset entirely, # and the 'text' column never gets tokenized to 'input_ids'. 
from transformers import ProcessorMixin sft_tokenizer = self.tokenizer if isinstance(self.tokenizer, ProcessorMixin) and hasattr( self.tokenizer, "tokenizer" ): logger.info( f" ⚠️ Unwrapping Processor → raw tokenizer for text-only SFTTrainer" ) sft_tokenizer = self.tokenizer.tokenizer trainer_kwargs = { "model": self.model, "tokenizer": sft_tokenizer, "train_dataset": dataset["dataset"], "data_collator": data_collator, "args": SFTConfig(**config_args), } if eval_dataset is not None: trainer_kwargs["eval_dataset"] = eval_dataset self.trainer = SFTTrainer(**trainer_kwargs) # Restore the full processor as processing_class so checkpoint # saves include preprocessor_config.json (needed for GGUF export). if sft_tokenizer is not self.tokenizer: self.trainer.processing_class = self.tokenizer logger.info("Trainer initialized\n") # ========== TRAIN ON RESPONSES ONLY ========== # Determine if we should train on responses only instruction_part = None response_part = None train_on_responses_enabled = training_args.get( "train_on_completions", False ) # DeepSeek OCR handles this internally in its collator, so skip # Audio VLM handles label masking in its collator, so skip if ( train_on_responses_enabled and not self.is_audio_vlm and not self.is_audio and not (is_deepseek_ocr or dataset["final_format"].lower() == "alpaca") ): try: logger.info("Configuring train on responses only...\n") # Get the template mapping for this model model_name_lower = self.model_name.lower() if model_name_lower in MODEL_TO_TEMPLATE_MAPPER: template_name = MODEL_TO_TEMPLATE_MAPPER[model_name_lower] logger.info(f"Detected template: {template_name}\n") if template_name in TEMPLATE_TO_RESPONSES_MAPPER: instruction_part = TEMPLATE_TO_RESPONSES_MAPPER[ template_name ]["instruction"] response_part = TEMPLATE_TO_RESPONSES_MAPPER[template_name][ "response" ] logger.info( f"Instruction marker: {instruction_part[:50]}...\n" ) logger.info(f"Response marker: {response_part[:50]}...\n") else: logger.info( f"No 
response mapping found for template: {template_name}\n" ) train_on_responses_enabled = False else: logger.info( f"No template mapping found for model: {self.model_name}\n" ) train_on_responses_enabled = False except Exception as e: logger.warning(f"Could not configure train on responses: {e}") train_on_responses_enabled = False # Apply train on responses only if we have valid parts if ( train_on_responses_enabled and instruction_part and response_part and not self.is_audio_vlm and not self.is_audio and not (is_deepseek_ocr or dataset["final_format"].lower() == "alpaca") ): try: from unsloth.chat_templates import train_on_responses_only self.trainer = train_on_responses_only( self.trainer, instruction_part = instruction_part, response_part = response_part, num_proc = config_args["dataset_num_proc"], ) logger.info("Train on responses only configured successfully\n") # ── Safety net: check if all samples were filtered out ── # Unsloth's train_on_responses_only masks non-response # tokens with -100. If max_seq_length is too short and the # response portion gets truncated away, EVERY sample ends # up with all labels == -100 and Unsloth removes them, # leaving 0 usable training samples. filtered_len = len(self.trainer.train_dataset) original_len = len(dataset["dataset"]) dropped = original_len - filtered_len drop_pct = ( round(100 * dropped / original_len, 1) if original_len > 0 else 0 ) if filtered_len == 0 or drop_pct > 30: max_seq = training_args.get("max_seq_length", 2048) error_msg = ( f"{dropped}/{original_len} samples ({drop_pct}%) " f"were dropped after applying 'train on responses " f"only' — only {filtered_len} remain. This usually " f"means max_seq_length ({max_seq}) is too short " f"and the response portion is being truncated " f"away. Try increasing max_seq_length (e.g. 8192) " f"or disabling 'Train on completions'." 
) logger.error(error_msg) self._update_progress(error = error_msg, is_training = False) return if dropped > 0: logger.info( f"⚠️ {dropped}/{original_len} samples " f"({drop_pct}%) were dropped (all labels " f"masked). {filtered_len} samples remain.\n" ) logger.info(f"Post-filter dataset size: {filtered_len} samples\n") # [DEBUG] Decode first sample AFTER train_on_completions applied # try: # _row = self.trainer.train_dataset[0] # _space = self.tokenizer( # " ", add_special_tokens = False # ).input_ids[0] # print("[DEBUG] === After train_on_completions ===", flush = True) # print( # f"[DEBUG] input_ids decoded:\n{self.tokenizer.decode(_row['input_ids'])}\n", # flush = True, # ) # print( # f"[DEBUG] labels decoded (-100 → space):\n{self.tokenizer.decode([_space if x == -100 else x for x in _row['labels']])}\n", # flush = True, # ) # except Exception as _dbg_e: # print( # f"[DEBUG] Could not decode post-completions sample: {_dbg_e}", # flush = True, # ) except Exception as e: logger.warning(f"Failed to apply train on responses only: {e}") train_on_responses_enabled = False else: if train_on_responses_enabled and is_deepseek_ocr: logger.info("Train on responses handled by DeepSeek OCR collator\n") else: logger.info("Training on full sequences (including prompts)\n") # ========== PROGRESS TRACKING ========== self.trainer.add_callback(self._create_progress_callback()) num_samples = len( dataset["dataset"] if isinstance(dataset, dict) else dataset ) batch_size = training_args.get("batch_size", 2) total_steps = self._calculate_total_steps( num_samples, batch_size, training_args.get("gradient_accumulation_steps", 4), training_args.get("num_epochs", 3), training_args.get("max_steps", 0), ) self._update_progress(total_steps = total_steps) # ========== START TRAINING ========== self._update_progress(status_message = "Starting training...") logger.info("Starting training...\n") self.trainer.train() # ========== SAVE MODEL ========== self._finalize_training(output_dir) except 
Exception as e: import traceback logger.error(f"Training error: {e}") logger.error(f"Full traceback:\n{traceback.format_exc()}") self._update_progress(is_training = False, error = str(e)) finally: self.is_training = False def _patch_adapter_config(self, output_dir: str) -> None: """Patch adapter_config.json with unsloth_training_method. Values: 'qlora', 'lora', 'FT', 'CPT', 'DPO', 'GRPO', etc. For LoRA/QLoRA, the distinction comes from load_in_4bit. """ config_path = os.path.join(output_dir, "adapter_config.json") if not os.path.exists(config_path): logger.info("No adapter_config.json found — skipping training method patch") return try: with open(config_path, "r") as f: config = json.load(f) # Determine the training method if self.load_in_4bit: method = "qlora" else: method = "lora" config["unsloth_training_method"] = method logger.info( f"Patching adapter_config.json with unsloth_training_method='{method}'" ) with open(config_path, "w") as f: json.dump(config, f, indent = 2) except Exception as e: logger.warning(f"Failed to patch adapter_config.json: {e}") def stop_training(self, save: bool = True): """Stop ongoing training""" logger.info(f"\nStopping training (save={save})...") self.should_stop = True self.save_on_stop = save stop_msg = ( "Stopping training and saving checkpoint..." if save else "Cancelling training..." 
) self._update_progress(status_message = stop_msg) # If trainer exists, try to stop it gracefully if self.trainer: try: # The callback will catch should_stop flag and stop the training loop logger.info("Training will stop at next step...\n") except Exception as e: logger.error(f"Error stopping trainer: {e}") def get_training_progress(self) -> TrainingProgress: """Get current training progress""" with self._lock: return self.training_progress def cleanup(self): """Cleanup resources""" if self.trainer: self.trainer = None if self.model: self.model = None if self.tokenizer: self.tokenizer = None # Clear GPU memory clear_gpu_cache() def _ensure_deepseek_ocr_installed(): """ Auto-install DeepSeek OCR module if not available. Downloads from HuggingFace hub as a local module. Returns: bool: True if available (either already installed or just installed) """ try: # Try importing to see if already available from deepseek_ocr.modeling_deepseekocr import format_messages logger.info("DeepSeek OCR module already available") return True except ImportError: pass try: logger.info( "DeepSeek OCR module not found. Auto-installing from HuggingFace..." 
) logger.info("\n Downloading DeepSeek OCR module from HuggingFace...\n") from huggingface_hub import snapshot_download import sys import os # Get the script directory to install locally script_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(script_dir) # Go up to project root # Download to project root as 'deepseek_ocr' folder local_dir = os.path.join(parent_dir, "deepseek_ocr") snapshot_download( "unsloth/DeepSeek-OCR", local_dir = local_dir, local_dir_use_symlinks = False ) # Add to sys.path if not already there if parent_dir not in sys.path: sys.path.insert(0, parent_dir) # Try importing again from deepseek_ocr.modeling_deepseekocr import format_messages logger.info("DeepSeek OCR module installed successfully") logger.info("DeepSeek OCR module installed successfully!\n") return True except Exception as e: logger.error(f"Failed to install DeepSeek OCR module: {e}") logger.info(f"\n❌ Failed to install DeepSeek OCR module: {e}\n") return False # Global trainer instance _trainer_instance = None def get_trainer() -> UnslothTrainer: """Get global trainer instance""" global _trainer_instance if _trainer_instance is None: _trainer_instance = UnslothTrainer() return _trainer_instance ================================================ FILE: studio/backend/core/training/training.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Training backend — subprocess orchestrator. Each training job runs in a fresh subprocess (mp.get_context("spawn")), solving the transformers version-switching problem. The old in-process UnslothTrainer singleton is only used inside the subprocess (worker.py). This file orchestrates the subprocess lifecycle, pumps events from the worker's mp.Queue, and exposes the same API surface to routes/training.py. Pattern follows core/data_recipe/jobs/manager.py. 
""" import math import multiprocessing as mp import queue import threading import time import structlog from loggers import get_logger from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Tuple, Any import matplotlib.pyplot as plt logger = get_logger(__name__) _CTX = mp.get_context("spawn") # Plot styling constants PLOT_WIDTH = 8 PLOT_HEIGHT = 3.5 @dataclass class TrainingProgress: """Mirror of trainer.TrainingProgress — kept here so the parent process never needs to import the heavy ML modules.""" epoch: float = 0 step: int = 0 total_steps: int = 0 loss: float = 0.0 learning_rate: float = 0.0 is_training: bool = False is_completed: bool = False error: Optional[str] = None status_message: str = "Ready to train" elapsed_seconds: Optional[float] = None eta_seconds: Optional[float] = None grad_norm: Optional[float] = None num_tokens: Optional[int] = None eval_loss: Optional[float] = None class TrainingBackend: """ Training orchestration backend — subprocess-based. Launches a fresh subprocess per training job, communicates via mp.Queue. 
""" def __init__(self): # Subprocess state self._proc: Optional[mp.Process] = None self._event_queue: Any = None self._stop_queue: Any = None self._pump_thread: Optional[threading.Thread] = None self._lock = threading.Lock() # Progress state (updated by pump thread from subprocess events) self._progress = TrainingProgress() self._should_stop = False self._cancel_requested = False # True only for stop(save=False) # Training Metrics (consumed by routes for SSE and /metrics) self.loss_history: list = [] self.lr_history: list = [] self.step_history: list = [] self.grad_norm_history: list = [] self.grad_norm_step_history: list = [] self.eval_loss_history: list = [] self.eval_step_history: list = [] self.eval_enabled: bool = False self.current_theme: str = "light" # Job metadata self.current_job_id: Optional[str] = None self._output_dir: Optional[str] = None logger.info("TrainingBackend initialized (subprocess mode)") # ------------------------------------------------------------------ # Public API (called by routes/training.py) # ------------------------------------------------------------------ def start_training(self, **kwargs) -> bool: """Spawn a subprocess to run the full training pipeline. All kwargs are serialized into a config dict and sent to the worker. Returns True if the subprocess was started successfully. """ with self._lock: if self._proc is not None and self._proc.is_alive(): logger.warning("Training subprocess already running") return False # Join prior pump thread to prevent it from consuming events # from the new job's queue (it reads self._event_queue dynamically). 
if self._pump_thread is not None and self._pump_thread.is_alive(): self._pump_thread.join(timeout = 5.0) if self._pump_thread.is_alive(): logger.warning("Previous pump thread did not exit within 5s") self._pump_thread = None # Reset state self._should_stop = False self._cancel_requested = False self._progress = TrainingProgress( is_training = True, status_message = "Initializing training..." ) self.loss_history.clear() self.lr_history.clear() self.step_history.clear() self.grad_norm_history.clear() self.grad_norm_step_history.clear() self.eval_loss_history.clear() self.eval_step_history.clear() self.eval_enabled = False self._output_dir = None # Build config dict for the subprocess config = { "model_name": kwargs["model_name"], "training_type": kwargs.get("training_type", "LoRA/QLoRA"), "hf_token": kwargs.get("hf_token", ""), "load_in_4bit": kwargs.get("load_in_4bit", True), "max_seq_length": kwargs.get("max_seq_length", 2048), "hf_dataset": kwargs.get("hf_dataset", ""), "local_datasets": kwargs.get("local_datasets"), "local_eval_datasets": kwargs.get("local_eval_datasets"), "format_type": kwargs.get("format_type", ""), "subset": kwargs.get("subset"), "train_split": kwargs.get("train_split", "train"), "eval_split": kwargs.get("eval_split"), "eval_steps": kwargs.get("eval_steps", 0.00), "dataset_slice_start": kwargs.get("dataset_slice_start"), "dataset_slice_end": kwargs.get("dataset_slice_end"), "custom_format_mapping": kwargs.get("custom_format_mapping"), "is_dataset_image": kwargs.get("is_dataset_image", False), "is_dataset_audio": kwargs.get("is_dataset_audio", False), "is_embedding": kwargs.get("is_embedding", False), "num_epochs": kwargs.get("num_epochs", 3), "learning_rate": kwargs.get("learning_rate", "2e-4"), "batch_size": kwargs.get("batch_size", 2), "gradient_accumulation_steps": kwargs.get("gradient_accumulation_steps", 4), "warmup_steps": kwargs.get("warmup_steps"), "warmup_ratio": kwargs.get("warmup_ratio"), "max_steps": kwargs.get("max_steps", 0), 
"save_steps": kwargs.get("save_steps", 0), "weight_decay": kwargs.get("weight_decay", 0.01), "random_seed": kwargs.get("random_seed", 3407), "packing": kwargs.get("packing", False), "optim": kwargs.get("optim", "adamw_8bit"), "lr_scheduler_type": kwargs.get("lr_scheduler_type", "linear"), "use_lora": kwargs.get("use_lora", True), "lora_r": kwargs.get("lora_r", 16), "lora_alpha": kwargs.get("lora_alpha", 16), "lora_dropout": kwargs.get("lora_dropout", 0.0), "target_modules": kwargs.get("target_modules"), "gradient_checkpointing": kwargs.get("gradient_checkpointing", "unsloth"), "use_rslora": kwargs.get("use_rslora", False), "use_loftq": kwargs.get("use_loftq", False), "train_on_completions": kwargs.get("train_on_completions", False), "finetune_vision_layers": kwargs.get("finetune_vision_layers", True), "finetune_language_layers": kwargs.get("finetune_language_layers", True), "finetune_attention_modules": kwargs.get( "finetune_attention_modules", True ), "finetune_mlp_modules": kwargs.get("finetune_mlp_modules", True), "enable_wandb": kwargs.get("enable_wandb", False), "wandb_token": kwargs.get("wandb_token"), "wandb_project": kwargs.get("wandb_project", "unsloth-training"), "enable_tensorboard": kwargs.get("enable_tensorboard", False), "tensorboard_dir": kwargs.get("tensorboard_dir", "runs"), "trust_remote_code": kwargs.get("trust_remote_code", False), } # Derive load_in_4bit from training_type if config["training_type"] != "LoRA/QLoRA": config["load_in_4bit"] = False # Spawn subprocess from .worker import run_training_process self._event_queue = _CTX.Queue() self._stop_queue = _CTX.Queue() self._proc = _CTX.Process( target = run_training_process, kwargs = { "event_queue": self._event_queue, "stop_queue": self._stop_queue, "config": config, }, daemon = True, ) self._proc.start() logger.info("Training subprocess started (pid=%s)", self._proc.pid) # Start event pump thread self._pump_thread = threading.Thread(target = self._pump_loop, daemon = True) 
self._pump_thread.start() return True def stop_training(self, save: bool = True) -> bool: """Send stop signal to the training subprocess.""" self._should_stop = True if not save: self._cancel_requested = True with self._lock: if self._stop_queue is not None: try: self._stop_queue.put({"type": "stop", "save": save}) except (OSError, ValueError): pass # Update progress immediately for responsive UI self._progress.status_message = ( "Stopping training and saving checkpoint..." if save else "Cancelling training..." ) return True def force_terminate(self) -> None: """Force-kill the training subprocess so state can be reset immediately.""" with self._lock: if self._proc is not None and self._proc.is_alive(): logger.info( "Force-terminating training subprocess (pid=%s)", self._proc.pid ) self._proc.terminate() proc = self._proc if proc is not None: proc.join(timeout = 5.0) if proc.is_alive(): proc.kill() proc.join(timeout = 2.0) def is_training_active(self) -> bool: """Check if training is currently active.""" with self._lock: # Subprocess alive = active if self._proc is not None and self._proc.is_alive(): return True # Stop was requested and process exited → inactive if self._should_stop: return False # Check progress state p = self._progress if p.is_training: return True if p.is_completed or p.error: return False # Check status message for activity indicators status_lower = (p.status_message or "").lower() if any( k in status_lower for k in [ "cancelled", "canceled", "stopped", "completed", "ready to train", ] ): return False if any( k in status_lower for k in [ "loading", "preparing", "training", "configuring", "tokenizing", "starting", "importing", ] ): return True return False def get_training_status(self, theme: str = "light") -> Tuple: """Get current training status and loss plot.""" with self._lock: progress = self._progress if not (progress.is_training or progress.is_completed or progress.error): return (None, progress) plot = self._create_loss_plot(progress, 
theme) return (plot, progress) def refresh_plot_for_theme(self, theme: str) -> Optional[plt.Figure]: """Refresh plot with new theme.""" if theme and isinstance(theme, str) and theme in ["light", "dark"]: self.current_theme = theme if self.loss_history: with self._lock: progress = self._progress return self._create_loss_plot(progress, self.current_theme) return None # ------------------------------------------------------------------ # Compatibility shims — routes/training.py accesses these # ------------------------------------------------------------------ class _TrainerShim: """Minimal shim so routes that access backend.trainer.* still work.""" def __init__(self, backend: "TrainingBackend"): self._backend = backend self.should_stop = False @property def training_progress(self): return self._backend._progress @training_progress.setter def training_progress(self, value): self._backend._progress = value def get_training_progress(self): return self._backend._progress def _update_progress(self, **kwargs): with self._backend._lock: for key, value in kwargs.items(): if hasattr(self._backend._progress, key): setattr(self._backend._progress, key, value) @property def trainer(self): """Compatibility shim for routes that access backend.trainer.*""" return self._TrainerShim(self) # ------------------------------------------------------------------ # Event pump (background thread) # ------------------------------------------------------------------ def _pump_loop(self) -> None: """Background thread: consume events from subprocess → update state.""" while True: if self._proc is None or self._event_queue is None: return # Try to read an event event = self._read_queue(self._event_queue, timeout_sec = 0.25) if event is not None: self._handle_event(event) continue # No event — check if process is still alive if self._proc.is_alive(): continue # Process exited — drain remaining events for e in self._drain_queue(self._event_queue): self._handle_event(e) # Mark as done if no explicit 
complete/error was received with self._lock: if self._progress.is_training: if self._should_stop: self._progress.is_training = False self._progress.status_message = "Training stopped." else: self._progress.is_training = False self._progress.error = ( self._progress.error or "Training process exited unexpectedly" ) return def _handle_event(self, event: dict) -> None: """Apply a subprocess event to local state.""" etype = event.get("type") with self._lock: if etype == "progress": self._progress.step = event.get("step", self._progress.step) self._progress.epoch = event.get("epoch", self._progress.epoch) self._progress.loss = event.get("loss", self._progress.loss) self._progress.learning_rate = event.get( "learning_rate", self._progress.learning_rate ) self._progress.total_steps = event.get( "total_steps", self._progress.total_steps ) self._progress.elapsed_seconds = event.get("elapsed_seconds") self._progress.eta_seconds = event.get("eta_seconds") self._progress.grad_norm = event.get("grad_norm") self._progress.num_tokens = event.get("num_tokens") self._progress.eval_loss = event.get("eval_loss") self._progress.is_training = True status = event.get("status_message", "") if status: self._progress.status_message = status # Update metric histories step = event.get("step", 0) loss = event.get("loss", 0.0) lr = event.get("learning_rate", 0.0) if step >= 0 and loss > 0: self.loss_history.append(loss) self.lr_history.append(lr) self.step_history.append(step) grad_norm = event.get("grad_norm") if grad_norm is not None: try: gn = float(grad_norm) except (TypeError, ValueError): gn = None if gn is not None and math.isfinite(gn): self.grad_norm_history.append(gn) self.grad_norm_step_history.append(step) eval_loss = event.get("eval_loss") if eval_loss is not None: self.eval_loss_history.append(eval_loss) self.eval_step_history.append(step) self.eval_enabled = True elif etype == "eval_configured": self.eval_enabled = True elif etype == "status": self._progress.status_message = 
event.get("message", "") self._progress.is_training = True elif etype == "complete": self._progress.is_training = False self._progress.is_completed = True self._output_dir = event.get("output_dir") msg = event.get("status_message", "Training completed") self._progress.status_message = msg elif etype == "error": self._progress.is_training = False self._progress.error = event.get("error", "Unknown error") logger.error("Training error: %s", event.get("error")) stack = event.get("stack", "") if stack: logger.error("Stack trace:\n%s", stack) @staticmethod def _read_queue(q: Any, timeout_sec: float) -> Optional[dict]: try: return q.get(timeout = timeout_sec) except queue.Empty: return None except (EOFError, OSError, ValueError): return None @staticmethod def _drain_queue(q: Any) -> list: events = [] while True: try: events.append(q.get_nowait()) except queue.Empty: return events except (EOFError, OSError, ValueError): return events # ------------------------------------------------------------------ # Plot generation (unchanged from original) # ------------------------------------------------------------------ def _create_loss_plot( self, progress: TrainingProgress, theme: str = "light" ) -> plt.Figure: """Create training loss plot with theme-aware styling.""" plt.close("all") LIGHT_STYLE = { "facecolor": "#ffffff", "grid_color": "#d1d5db", "line": "#16b88a", "text": "#1f2937", "empty_text": "#6b7280", } DARK_STYLE = { "facecolor": "#292929", "grid_color": "#404040", "line": "#4ade80", "text": "#e5e7eb", "empty_text": "#9ca3af", } style = LIGHT_STYLE if theme == "light" else DARK_STYLE fig, ax = plt.subplots(figsize = (PLOT_WIDTH, PLOT_HEIGHT)) fig.patch.set_facecolor(style["facecolor"]) ax.set_facecolor(style["facecolor"]) if self.loss_history: steps = self.step_history losses = self.loss_history scatter_color = "#60a5fa" ax.scatter( steps, losses, s = 16, alpha = 0.6, color = scatter_color, linewidths = 0, label = "Training Loss (raw)", ) MA_WINDOW = 20 window = 
min(MA_WINDOW, len(losses)) if window >= 2: cumsum = [0.0] for v in losses: cumsum.append(cumsum[-1] + float(v)) ma = [] for i in range(len(losses)): start = max(0, i - window + 1) denom = i - start + 1 ma.append((cumsum[i + 1] - cumsum[start]) / denom) ax.plot( steps, ma, color = style["line"], linewidth = 2.5, alpha = 0.95, label = f"Moving Avg ({ma[-1]:.4f})", ) leg = ax.legend(frameon = False, fontsize = 9) for t in leg.get_texts(): t.set_color(style["text"]) ax.set_xlabel("Steps", fontsize = 10, color = style["text"]) ax.set_ylabel("Loss", fontsize = 10, color = style["text"]) if progress.error: title = f"Error: {progress.error}" elif progress.is_completed: title = f"Training completed! Final loss: {progress.loss:.4f}" elif progress.status_message: title = progress.status_message elif progress.step > 0: title = f"Epoch: {progress.epoch} | Step: {progress.step}/{progress.total_steps} | Loss: {progress.loss:.4f}" else: title = "Training Loss" ax.set_title( title, fontsize = 11, fontweight = "bold", pad = 10, color = style["text"] ) ax.grid(True, alpha = 0.4, linestyle = "--", color = style["grid_color"]) ax.tick_params(colors = style["text"], which = "both") ax.spines["top"].set_visible(False) ax.spines["right"].set_visible(False) ax.spines["bottom"].set_color(style["text"]) ax.spines["left"].set_color(style["text"]) else: display_msg = ( progress.status_message if progress.status_message else "Waiting for training data..." ) ax.text( 0.5, 0.5, display_msg, ha = "center", va = "center", fontsize = 16, color = style["empty_text"], transform = ax.transAxes, ) ax.set_xticks([]) ax.set_yticks([]) for spine in ax.spines.values(): spine.set_visible(False) fig.tight_layout() return fig def _transfer_to_inference_backend(self) -> bool: """Transfer model to inference backend. With subprocess-based training, the model lives in the subprocess and is freed when it exits. Inference must load from the saved checkpoint on disk. This is a no-op placeholder. 
""" logger.info( "_transfer_to_inference_backend: subprocess training — " "model must be loaded from disk (output_dir=%s)", self._output_dir, ) return False # ========== GLOBAL INSTANCE ========== _training_backend = None def get_training_backend() -> TrainingBackend: """Get global training backend instance""" global _training_backend if _training_backend is None: _training_backend = TrainingBackend() return _training_backend ================================================ FILE: studio/backend/core/training/worker.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Training subprocess entry point. Each training job runs in a fresh subprocess (mp.get_context("spawn")). This gives us a clean Python interpreter with no stale module state — solving the transformers version-switching problem completely. Pattern follows core/data_recipe/jobs/worker.py. """ from __future__ import annotations import structlog from loggers import get_logger import os import sys import time import traceback from pathlib import Path from typing import Any logger = get_logger(__name__) def _activate_transformers_version(model_name: str) -> None: """Activate the correct transformers version BEFORE any ML imports. If the model needs transformers 5.x, prepend the pre-installed .venv_t5/ directory to sys.path. Otherwise do nothing (default 4.57.x in .venv/). 
""" # Ensure backend is on path for utils imports backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from utils.transformers_version import ( needs_transformers_5, _resolve_base_model, _ensure_venv_t5_exists, _VENV_T5_DIR, ) resolved = _resolve_base_model(model_name) if needs_transformers_5(resolved): if not _ensure_venv_t5_exists(): raise RuntimeError( f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}" ) if _VENV_T5_DIR not in sys.path: sys.path.insert(0, _VENV_T5_DIR) logger.info("Activated transformers 5.x from %s", _VENV_T5_DIR) # Propagate to child subprocesses (e.g. GGUF converter) _pp = os.environ.get("PYTHONPATH", "") os.environ["PYTHONPATH"] = _VENV_T5_DIR + (os.pathsep + _pp if _pp else "") else: logger.info("Using default transformers (4.57.x) for %s", model_name) def run_training_process( *, event_queue: Any, stop_queue: Any, config: dict, ) -> None: """Subprocess entrypoint. Fresh Python — no stale module state. Args: event_queue: mp.Queue for sending progress/status/error events to parent. stop_queue: mp.Queue for receiving stop commands from parent. config: Training configuration dict with all parameters. """ os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["PYTHONWARNINGS"] = ( "ignore" # Suppress warnings at C-level before imports ) import warnings from loggers.config import LogConfig if os.getenv("ENVIRONMENT_TYPE", "production") == "production": warnings.filterwarnings("ignore") LogConfig.setup_logging( service_name = "unsloth-studio-training-worker", env = os.getenv("ENVIRONMENT_TYPE", "production"), ) model_name = config["model_name"] # ── 1. 
Activate correct transformers version BEFORE any ML imports ── try: _activate_transformers_version(model_name) except Exception as exc: event_queue.put( { "type": "error", "error": f"Failed to activate transformers version: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── 1a. Auto-enable trust_remote_code for unsloth/* transformers 5.x models ── # Some newer architectures (e.g. NemotronH) have config parsing bugs in # transformers that require trust_remote_code=True as a workaround. # Only auto-enable for unsloth/* prefixed models (trusted source). from utils.transformers_version import needs_transformers_5 if ( needs_transformers_5(model_name) and model_name.lower().startswith("unsloth/") and not config.get("trust_remote_code", False) ): config["trust_remote_code"] = True logger.info( "Auto-enabled trust_remote_code for unsloth/* transformers 5.x model: %s", model_name, ) # ── 1b. Auto-install mamba-ssm for SSM/hybrid models (NemotronH, Falcon-H1) ── _SSM_MODEL_SUBSTRINGS = ("nemotron_h", "nemotron-3-nano", "falcon_h1", "falcon-h1") if any(sub in model_name.lower() for sub in _SSM_MODEL_SUBSTRINGS): try: import mamba_ssm # noqa: F401 logger.info("mamba-ssm already installed") except ImportError: logger.info( "SSM model detected — installing mamba-ssm and causal-conv1d (this may take several minutes)..." ) _send_status( event_queue, "Installing mamba-ssm (first time only, ~7 min)..." 
) import subprocess as _sp # --no-build-isolation: compile against current torch (no version conflicts) # --no-deps: don't pull in torch/transformers/triton (already installed) for _pkg in ["causal_conv1d", "mamba_ssm"]: _r = _sp.run( [ sys.executable, "-m", "pip", "install", "--no-build-isolation", "--no-deps", "--no-cache-dir", _pkg, ], stdout = _sp.PIPE, stderr = _sp.STDOUT, text = True, ) if _r.returncode != 0: logger.error("Failed to install %s:\n%s", _pkg, _r.stdout) else: logger.info("Installed %s successfully", _pkg) logger.info("mamba-ssm installation complete") # ── 1c. Set fork start method so dataset.map() can multiprocess ── # The parent launched us via spawn (clean process), but the compiled # SFTTrainer checks get_start_method() and disables num_proc if not "fork". # Linux only: fork is the default start method and is safe here (no CUDA # context exists yet). macOS defaults to spawn since Python 3.8 because # fork is unsafe with macOS frameworks (Metal/MPS, CoreFoundation) -- # do NOT override on macOS. Windows has no fork at all. if sys.platform == "linux": import multiprocessing as _mp try: _mp.set_start_method("fork", force = True) except RuntimeError: pass # Already set # ── 1c. On Windows, check Triton availability (must be before import torch) ── if sys.platform == "win32": try: import triton # noqa: F401 logger.info("Triton available — torch.compile enabled") except ImportError: os.environ["TORCHDYNAMO_DISABLE"] = "1" logger.warning( "Triton not found on Windows — torch.compile disabled. " 'Install for better performance: pip install "triton-windows<3.7"' ) # ── 2. 
Now import ML libraries (fresh in this clean process) ── try: _send_status(event_queue, "Importing Unsloth...") backend_path = str(Path(__file__).resolve().parent.parent.parent) if backend_path not in sys.path: sys.path.insert(0, backend_path) from core.training.trainer import UnslothTrainer, TrainingProgress from utils.paths import ( ensure_dir, resolve_output_dir, resolve_tensorboard_dir, datasets_root, ) import transformers logger.info("Subprocess loaded transformers %s", transformers.__version__) except Exception as exc: event_queue.put( { "type": "error", "error": f"Failed to import ML libraries: {exc}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── 2b. EMBEDDING MODEL FAST-PATH ── # Embedding models use a completely different pipeline (FastSentenceTransformer # + SentenceTransformerTrainer + MultipleNegativesRankingLoss) so we branch # early and handle the entire flow in a self-contained function. if config.get("is_embedding", False): try: _run_embedding_training(event_queue, stop_queue, config) except Exception as exc: event_queue.put( { "type": "error", "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── 3. 
Create a fresh trainer instance ── trainer = UnslothTrainer() # Wire up progress callback → event_queue def _on_progress(progress: TrainingProgress): has_train_loss = progress.step >= 0 and progress.loss > 0 has_eval_loss = progress.eval_loss is not None if has_train_loss or has_eval_loss: event_queue.put( { "type": "progress", "step": progress.step, "epoch": progress.epoch, "loss": progress.loss, "learning_rate": progress.learning_rate, "total_steps": progress.total_steps, "elapsed_seconds": progress.elapsed_seconds, "eta_seconds": progress.eta_seconds, "grad_norm": progress.grad_norm, "num_tokens": progress.num_tokens, "eval_loss": progress.eval_loss, "status_message": progress.status_message, "ts": time.time(), } ) if progress.status_message: _send_status(event_queue, progress.status_message) trainer.add_progress_callback(_on_progress) # Wire up stop_queue polling to trainer.should_stop import threading import queue as _queue def _poll_stop(): while True: try: msg = stop_queue.get(timeout = 1.0) if msg and msg.get("type") == "stop": save = msg.get("save", True) trainer.should_stop = True trainer.save_on_stop = save logger.info("Stop signal received (save=%s)", save) return except _queue.Empty: continue except (EOFError, OSError): return stop_thread = threading.Thread(target = _poll_stop, daemon = True) stop_thread.start() # ── 4. Execute the training pipeline ── # Order: detect → dataset → model → prepare → train # Dataset processing (including LLM-assisted detection) runs BEFORE model # loading so both never occupy VRAM at the same time. try: hf_token = config.get("hf_token", "") hf_token = hf_token if hf_token and hf_token.strip() else None # ── 4a. 
Lightweight detection + tokenizer (no VRAM) ── _send_status(event_queue, "Detecting model type...") trainer.pre_detect_and_load_tokenizer( model_name = model_name, max_seq_length = config["max_seq_length"], hf_token = hf_token, is_dataset_image = config.get("is_dataset_image", False), is_dataset_audio = config.get("is_dataset_audio", False), trust_remote_code = config.get("trust_remote_code", False), ) if trainer.should_stop: event_queue.put({"type": "complete", "output_dir": None, "ts": time.time()}) return # ── 4b. Load and format dataset (LLM helper may use VRAM briefly) ── _send_status(event_queue, "Loading and formatting dataset...") hf_dataset = config.get("hf_dataset", "") dataset_result = trainer.load_and_format_dataset( dataset_source = hf_dataset if hf_dataset and hf_dataset.strip() else None, format_type = config.get("format_type", ""), local_datasets = config.get("local_datasets") or None, local_eval_datasets = config.get("local_eval_datasets") or None, custom_format_mapping = config.get("custom_format_mapping"), subset = config.get("subset"), train_split = config.get("train_split", "train"), eval_split = config.get("eval_split"), eval_steps = config.get("eval_steps", 0.00), dataset_slice_start = config.get("dataset_slice_start"), dataset_slice_end = config.get("dataset_slice_end"), ) if isinstance(dataset_result, tuple): dataset, eval_dataset = dataset_result else: dataset = dataset_result eval_dataset = None # [DEBUG] Print first sample before model is loaded # dataset is a dict {"dataset": , "detected_format": ..., ...} # or a raw Dataset for audio paths # try: # ds = dataset["dataset"] if isinstance(dataset, dict) else dataset # print( # f"\n[DEBUG] Dataset loaded BEFORE model. 
type={type(ds).__name__}, len={len(ds)}", # flush = True, # ) # print(f"[DEBUG] Columns: {ds.column_names}", flush = True) # sample = ds[0] # preview = {k: str(v)[:300] for k, v in sample.items()} # print(f"[DEBUG] First sample: {preview}\n", flush = True) # except Exception as e: # print( # f"[DEBUG] Could not preview first sample: {type(e).__name__}: {e}", # flush = True, # ) # Disable eval if eval_steps <= 0 eval_steps = config.get("eval_steps", 0.00) if eval_steps is not None and float(eval_steps) <= 0: eval_dataset = None # Tell the parent process that eval is configured so the frontend # shows "Waiting for first evaluation step..." instead of "not configured" if eval_dataset is not None: event_queue.put( { "type": "eval_configured", "ts": time.time(), } ) if dataset is None or trainer.should_stop: if trainer.should_stop: event_queue.put( {"type": "complete", "output_dir": None, "ts": time.time()} ) else: event_queue.put( { "type": "error", "error": trainer.training_progress.error or "Failed to load dataset", "stack": "", "ts": time.time(), } ) return # ── Start tqdm monitor early so it captures download + tokenization bars ── import threading as _th _tqdm_stop = _th.Event() def _monitor_tqdm(): from tqdm.auto import tqdm as _tqdm_cls while not _tqdm_stop.is_set(): for bar in list(getattr(_tqdm_cls, "_instances", set())): try: n, total = bar.n or 0, bar.total or 0 desc = getattr(bar, "desc", "") or "" if total > 0 and n > 0 and desc: pct = min(int(n * 100 / total), 100) _send_status( event_queue, f"{desc.strip()} {pct}% ({n:,}/{total:,})" ) except (AttributeError, ReferenceError): pass _tqdm_stop.wait(3) _tqdm_thread = _th.Thread(target = _monitor_tqdm, daemon = True) _tqdm_thread.start() training_type = config.get("training_type", "LoRA/QLoRA") use_lora = training_type == "LoRA/QLoRA" # ── 4c. 
Load training model (uses VRAM — dataset already formatted) ── _send_status(event_queue, "Loading model...") success = trainer.load_model( model_name = model_name, max_seq_length = config["max_seq_length"], load_in_4bit = config["load_in_4bit"], full_finetuning = not use_lora, hf_token = hf_token, is_dataset_image = config.get("is_dataset_image", False), is_dataset_audio = config.get("is_dataset_audio", False), trust_remote_code = config.get("trust_remote_code", False), ) if not success or trainer.should_stop: if trainer.should_stop: event_queue.put( {"type": "complete", "output_dir": None, "ts": time.time()} ) else: error_msg = trainer.training_progress.error or "Failed to load model" event_queue.put( { "type": "error", "error": error_msg, "stack": "", "ts": time.time(), } ) return # ── 4d. Prepare model (LoRA or full finetuning) ── if use_lora: _send_status(event_queue, "Configuring LoRA adapters...") success = trainer.prepare_model_for_training( use_lora = True, finetune_vision_layers = config.get("finetune_vision_layers", True), finetune_language_layers = config.get("finetune_language_layers", True), finetune_attention_modules = config.get( "finetune_attention_modules", True ), finetune_mlp_modules = config.get("finetune_mlp_modules", True), target_modules = config.get("target_modules"), lora_r = config.get("lora_r", 16), lora_alpha = config.get("lora_alpha", 16), lora_dropout = config.get("lora_dropout", 0.0), use_gradient_checkpointing = config.get( "gradient_checkpointing", "unsloth" ), use_rslora = config.get("use_rslora", False), use_loftq = config.get("use_loftq", False), ) else: _send_status(event_queue, "Preparing model for full finetuning...") success = trainer.prepare_model_for_training(use_lora = False) if not success or trainer.should_stop: if trainer.should_stop: event_queue.put( {"type": "complete", "output_dir": None, "ts": time.time()} ) else: event_queue.put( { "type": "error", "error": trainer.training_progress.error or "Failed to prepare 
model", "stack": "", "ts": time.time(), } ) return # Convert learning rate try: lr_value = float(config.get("learning_rate", "2e-4")) except ValueError: event_queue.put( { "type": "error", "error": f"Invalid learning rate: {config.get('learning_rate')}", "stack": "", "ts": time.time(), } ) return # Generate output dir output_dir = config.get("output_dir") if not output_dir: output_dir = f"{model_name.replace('/', '_')}_{int(time.time())}" output_dir = str(resolve_output_dir(output_dir)) ensure_dir(Path(output_dir)) tensorboard_dir = config.get("tensorboard_dir") if config.get("enable_tensorboard", False): tensorboard_dir = str(resolve_tensorboard_dir(tensorboard_dir)) ensure_dir(Path(tensorboard_dir)) # Start training (directly — no inner thread, we ARE the subprocess) dataset_display = ( config.get("hf_dataset", "") or config.get("uploaded_file", "") or "" ) _send_status( event_queue, f'Training "{model_name}"' + (f"\nDataset = {dataset_display}" if dataset_display else ""), ) max_steps = config.get("max_steps", 0) save_steps = config.get("save_steps", 0) trainer._train_worker( dataset, output_dir = output_dir, num_epochs = config.get("num_epochs", 3), learning_rate = lr_value, batch_size = config.get("batch_size", 2), gradient_accumulation_steps = config.get("gradient_accumulation_steps", 4), warmup_steps = config.get("warmup_steps"), warmup_ratio = config.get("warmup_ratio"), max_steps = max_steps if max_steps and max_steps > 0 else 0, save_steps = save_steps if save_steps and save_steps > 0 else 0, weight_decay = config.get("weight_decay", 0.01), random_seed = config.get("random_seed", 3407), packing = config.get("packing", False), train_on_completions = config.get("train_on_completions", False), enable_wandb = config.get("enable_wandb", False), wandb_project = config.get("wandb_project", "unsloth-training"), wandb_token = config.get("wandb_token"), enable_tensorboard = config.get("enable_tensorboard", False), tensorboard_dir = tensorboard_dir, eval_dataset = 
eval_dataset, eval_steps = eval_steps, max_seq_length = config.get("max_seq_length", 2048), optim = config.get("optim", "adamw_8bit"), lr_scheduler_type = config.get("lr_scheduler_type", "linear"), ) _tqdm_stop.set() # Check final state progress = trainer.get_training_progress() if progress.error: event_queue.put( { "type": "error", "error": progress.error, "stack": "", "ts": time.time(), } ) else: event_queue.put( { "type": "complete", "output_dir": output_dir, "status_message": progress.status_message or "Training completed", "ts": time.time(), } ) except Exception as exc: event_queue.put( { "type": "error", "error": str(exc), "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) def _send_status(event_queue: Any, message: str) -> None: """Send a status update to the parent process.""" event_queue.put( { "type": "status", "message": message, "ts": time.time(), } ) def _run_embedding_training(event_queue: Any, stop_queue: Any, config: dict) -> None: """Self-contained embedding model training pipeline. Uses FastSentenceTransformer + SentenceTransformerTrainer + MultipleNegativesRankingLoss — completely separate from the LLM/VLM/audio paths in UnslothTrainer. Mirrors the pattern from the reference embedding notebooks: All_MiniLM_L6_v2.py, BGE_M3.py, EmbeddingGemma_300M.py, ModernBert.py, Qwen3_Embedding_0_6B.py """ import math import queue as _queue import threading model_name = config["model_name"] training_start_time = time.time() # ── 1. 
Import embedding-specific libraries ── _send_status(event_queue, "Importing embedding libraries...") try: from unsloth import FastSentenceTransformer, is_bfloat16_supported from sentence_transformers import ( SentenceTransformerTrainer, SentenceTransformerTrainingArguments, ) from sentence_transformers.losses import MultipleNegativesRankingLoss from sentence_transformers.training_args import BatchSamplers from datasets import load_dataset, Dataset from transformers import TrainerCallback from utils.paths import datasets_root, resolve_output_dir except ImportError as e: event_queue.put( { "type": "error", "error": f"Failed to import embedding libraries: {e}. " "Ensure 'sentence_transformers' and 'unsloth' are installed.", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── Stop signal handling ── _should_stop = False _save_on_stop = True def _poll_stop(): nonlocal _should_stop, _save_on_stop while True: try: msg = stop_queue.get(timeout = 1.0) if msg and msg.get("type") == "stop": _save_on_stop = msg.get("save", True) _should_stop = True logger.info( "Embedding training: stop signal received (save=%s)", _save_on_stop, ) return except _queue.Empty: continue except (EOFError, OSError): return stop_thread = threading.Thread(target = _poll_stop, daemon = True) stop_thread.start() # ── 2. 
Load model ── _send_status(event_queue, "Loading embedding model...") try: hf_token = config.get("hf_token", "") hf_token = hf_token if hf_token and hf_token.strip() else None max_seq_length = config.get("max_seq_length", 512) training_type = config.get("training_type", "LoRA/QLoRA") use_lora = training_type == "LoRA/QLoRA" model = FastSentenceTransformer.from_pretrained( model_name = model_name, max_seq_length = max_seq_length, full_finetuning = not use_lora, token = hf_token, ) except Exception as e: event_queue.put( { "type": "error", "error": f"Failed to load embedding model '{model_name}': {e}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return if _should_stop: event_queue.put({"type": "complete", "output_dir": None, "ts": time.time()}) return # ── 3. Apply LoRA ── if use_lora: _send_status(event_queue, "Configuring LoRA adapters (FEATURE_EXTRACTION)...") try: gradient_checkpointing = config.get("gradient_checkpointing", False) # Normalize: "none" or empty → False if gradient_checkpointing in ("none", "", None): gradient_checkpointing = False model = FastSentenceTransformer.get_peft_model( model, r = config.get("lora_r", 32), target_modules = config.get("target_modules") or ["q_proj", "k_proj", "v_proj", "o_proj"], lora_alpha = config.get("lora_alpha", 64), lora_dropout = config.get("lora_dropout", 0.0), bias = "none", use_gradient_checkpointing = gradient_checkpointing, random_state = config.get("random_seed", 3407), use_rslora = config.get("use_rslora", False), loftq_config = {"loftq_bits": 4, "loftq_iter": 1} if config.get("use_loftq") else None, task_type = "FEATURE_EXTRACTION", ) except Exception as e: event_queue.put( { "type": "error", "error": f"Failed to configure LoRA for embedding model: {e}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return if _should_stop: event_queue.put({"type": "complete", "output_dir": None, "ts": time.time()}) return # ── 4. 
Load dataset ── _send_status(event_queue, "Loading dataset...") try: hf_dataset = config.get("hf_dataset", "") local_datasets = config.get("local_datasets") or [] subset = config.get("subset") or None train_split = config.get("train_split", "train") or "train" if hf_dataset and hf_dataset.strip(): hf_token = config.get("hf_token", "") hf_token = hf_token if hf_token and hf_token.strip() else None dataset = load_dataset( hf_dataset.strip(), subset, split = train_split, token = hf_token, ) elif local_datasets: # Load from local file(s) — mirrors the non-embedding pipeline's # directory handling so recipe outputs (parquet-files/) work. all_files: list[str] = [] for dataset_file in local_datasets: file_path = ( dataset_file if os.path.isabs(dataset_file) else os.path.join( str(datasets_root()), dataset_file, ) ) if os.path.isdir(file_path): file_path_obj = Path(file_path) parquet_dir = ( file_path_obj / "parquet-files" if (file_path_obj / "parquet-files").exists() else file_path_obj ) parquet_files = sorted(parquet_dir.glob("*.parquet")) if parquet_files: all_files.extend(str(p) for p in parquet_files) continue candidates: list[Path] = [] for ext in (".json", ".jsonl", ".csv", ".parquet"): candidates.extend(sorted(file_path_obj.glob(f"*{ext}"))) if candidates: all_files.extend(str(c) for c in candidates) continue raise ValueError( f"No supported data files in directory: {file_path_obj}" ) else: all_files.append(file_path) if all_files: first_ext = Path(all_files[0]).suffix.lower() if first_ext in (".json", ".jsonl"): loader = "json" elif first_ext == ".csv": loader = "csv" elif first_ext == ".parquet": loader = "parquet" else: raise ValueError( f"Unsupported local dataset format: {all_files[0]}" ) dataset = load_dataset(loader, data_files = all_files, split = "train") else: event_queue.put( { "type": "error", "error": "No dataset specified for embedding training.", "stack": "", "ts": time.time(), } ) return # Apply dataset slicing if specified slice_start = 
config.get("dataset_slice_start") slice_end = config.get("dataset_slice_end") if slice_start is not None or slice_end is not None: start = slice_start if slice_start is not None else 0 end = slice_end if slice_end is not None else len(dataset) dataset = dataset.select(range(start, min(end + 1, len(dataset)))) logger.info(f"Embedding dataset loaded: {len(dataset)} samples") except Exception as e: event_queue.put( { "type": "error", "error": f"Failed to load dataset: {e}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return if _should_stop: event_queue.put({"type": "complete", "output_dir": None, "ts": time.time()}) return # ── 5. Create loss function ── loss = MultipleNegativesRankingLoss(model) # ── 6. Build training arguments ── _send_status(event_queue, "Configuring training...") try: lr_value = float(config.get("learning_rate", "2e-4")) except ValueError: event_queue.put( { "type": "error", "error": f"Invalid learning rate: {config.get('learning_rate')}", "stack": "", "ts": time.time(), } ) return output_dir = config.get("output_dir") if not output_dir: output_dir = str( resolve_output_dir(f"{model_name.replace('/', '_')}_{int(time.time())}") ) num_epochs = config.get("num_epochs", 2) batch_size = config.get("batch_size", 256) gradient_accumulation_steps = config.get("gradient_accumulation_steps", 1) max_steps_val = config.get("max_steps", 0) save_steps_val = config.get("save_steps", 0) warmup_ratio = config.get("warmup_ratio", 0.03) warmup_steps_val = config.get("warmup_steps") log_frequency = config.get("log_frequency", 50) # Build args dict training_args_kwargs = { "output_dir": output_dir, "per_device_train_batch_size": batch_size, "gradient_accumulation_steps": gradient_accumulation_steps, "learning_rate": lr_value, "fp16": not is_bfloat16_supported(), "bf16": is_bfloat16_supported(), "logging_steps": 1, "report_to": ["wandb"] if config.get("enable_wandb") else "none", "lr_scheduler_type": config.get("lr_scheduler_type", "linear"), 
"batch_sampler": BatchSamplers.NO_DUPLICATES, "optim": config.get("optim", "adamw_8bit"), "weight_decay": config.get("weight_decay", 0.01), "seed": config.get("random_seed", 3407), } # max_steps vs epochs if max_steps_val and max_steps_val > 0: training_args_kwargs["max_steps"] = max_steps_val else: training_args_kwargs["num_train_epochs"] = num_epochs if num_epochs > 0 else 2 # warmup: prefer warmup_ratio (standard for embedding scripts), fallback to steps if warmup_ratio is not None and warmup_ratio > 0: training_args_kwargs["warmup_ratio"] = warmup_ratio elif warmup_steps_val is not None and warmup_steps_val > 0: training_args_kwargs["warmup_steps"] = warmup_steps_val # save_steps if save_steps_val and save_steps_val > 0: training_args_kwargs["save_steps"] = save_steps_val training_args_kwargs["save_strategy"] = "steps" args = SentenceTransformerTrainingArguments(**training_args_kwargs) # ── 7. Calculate total steps for progress tracking ── if max_steps_val and max_steps_val > 0: total_steps = max_steps_val else: effective_epochs = num_epochs if num_epochs > 0 else 2 len_dataloader = math.ceil(len(dataset) / batch_size) steps_per_epoch = max(len_dataloader // gradient_accumulation_steps, 1) total_steps = steps_per_epoch * effective_epochs # ── 8. 
Create progress callback ── class _EmbeddingProgressCallback(TrainerCallback): """Sends training progress events to the parent process via event_queue.""" def on_log(self, args, state, control, logs = None, **kwargs): if not logs: return loss_value = logs.get("loss", logs.get("train_loss", 0.0)) current_step = state.global_step elapsed = time.time() - training_start_time eta = None if current_step > 0 and total_steps > 0: remaining = total_steps - current_step if remaining > 0: eta = (elapsed / current_step) * remaining event_queue.put( { "type": "progress", "step": current_step, "epoch": round(state.epoch, 2) if state.epoch else 0, "loss": loss_value, "learning_rate": logs.get("learning_rate", 0.0), "total_steps": total_steps, "elapsed_seconds": elapsed, "eta_seconds": eta, "grad_norm": logs.get("grad_norm"), "num_tokens": getattr(state, "num_input_tokens_seen", None), "eval_loss": logs.get("eval_loss"), "status_message": "", "ts": time.time(), } ) def on_step_end(self, args, state, control, **kwargs): if _should_stop: logger.info("Embedding training: stop at step %d", state.global_step) control.should_training_stop = True return control # ── 9. Create trainer and train ── _send_status(event_queue, "Starting embedding training...") try: trainer = SentenceTransformerTrainer( model = model, train_dataset = dataset, loss = loss, args = args, callbacks = [_EmbeddingProgressCallback()], ) trainer.train() except Exception as e: event_queue.put( { "type": "error", "error": f"Embedding training failed: {e}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── 10. 
Save model ── if _should_stop and not _save_on_stop: event_queue.put( { "type": "complete", "output_dir": None, "status_message": "Training cancelled", "ts": time.time(), } ) return _send_status(event_queue, "Saving model...") try: model.save_pretrained(output_dir) model.tokenizer.save_pretrained(output_dir) logger.info("Embedding model saved to %s", output_dir) except Exception as e: logger.error("Failed to save embedding model: %s", e) event_queue.put( { "type": "error", "error": f"Training completed but failed to save: {e}", "stack": traceback.format_exc(limit = 20), "ts": time.time(), } ) return # ── 11. Done ── event_queue.put( { "type": "complete", "output_dir": output_dir, "status_message": "Embedding training completed", "ts": time.time(), } ) ================================================ FILE: studio/backend/loggers/.gitkeep ================================================ ================================================ FILE: studio/backend/loggers/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from .handlers import get_logger __all__ = ["get_logger"] ================================================ FILE: studio/backend/loggers/config.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Logging configuration for structured logging with structlog. This module provides centralized logging configuration with environment-specific formats and processors. Supports both development and production environments with consistent structured logging. 
Key Features: - Environment-specific formatting (JSON for production, console for development) - Timestamp standardization (ISO format) - Context variable integration - Log level filtering - Logger caching for performance """ import logging import os import sys from typing import Optional import structlog class LogConfig: """Structured logging configuration for the application. Provides static method to configure structlog with environment-specific formatting and processors for consistent structured logging. """ @staticmethod def setup_logging( service_name: str = "unsloth-studio-backend", env: Optional[str] = None ) -> structlog.BoundLogger: """Configure structured logging for the application. Args: service_name: Name of the service for logging identification env: Environment (development/production), affects logging format """ # Determine log level from environment log_level_name = os.getenv("LOG_LEVEL", "INFO").upper() # Fallback to INFO if an invalid level is provided log_level = getattr(logging, log_level_name, logging.INFO) structlog.configure( processors = [ # Reorder processors to control field order structlog.processors.TimeStamper(fmt = "iso"), # timestamp first structlog.processors.add_log_level, # level second structlog.contextvars.merge_contextvars, # Custom processor to flatten the extra field lambda logger, method_name, event_dict: { "timestamp": event_dict.get("timestamp"), "level": event_dict.get("level"), "event": event_dict.get("event"), **(event_dict.get("extra", {})), # Flatten extra into main dict **{ k: v for k, v in event_dict.items() if k not in ["timestamp", "level", "event", "extra"] }, }, ( structlog.processors.JSONRenderer(sort_keys = False) # Preserve order if env == "production" else structlog.dev.ConsoleRenderer() ), ], wrapper_class = structlog.make_filtering_bound_logger(log_level), logger_factory = structlog.PrintLoggerFactory(file = sys.stdout), cache_logger_on_first_use = True, ) return structlog.get_logger(service_name) 
================================================ FILE: studio/backend/loggers/handlers.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Logging handlers and middleware for structured logging. This module provides FastAPI middleware and structlog processors for: - Request/response logging with timing - Sensitive data filtering in logs - Structured logging configuration - Error handling with detailed context Key Components: - LoggingMiddleware: FastAPI middleware for request/response logging - filter_sensitive_data: Structlog processor for data sanitization - get_logger: Factory function for structured loggers """ import time from typing import Callable import structlog from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware logger = structlog.get_logger(__name__) class LoggingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: Callable) -> Response: start_time = time.time() try: response = await call_next(request) # Log response process_time = (time.time() - start_time) * 1000 EXCLUDED_PATHS = { "/api/train/status", "/api/train/metrics", "/api/train/hardware", "/api/system", } is_excluded = ( request.url.path in EXCLUDED_PATHS or request.url.path.startswith("/assets/") or request.url.path.endswith( (".png", ".jpg", ".jpeg", ".ico", ".woff", ".woff2", ".ttf") ) ) if not is_excluded: logger.info( "request_completed", method = request.method, path = request.url.path, status_code = response.status_code, process_time_ms = round(process_time, 2), ) return response except Exception as e: logger.error( "request_failed", path = request.url.path, method = request.method, error = str(e), exc_info = True, ) raise def filter_sensitive_data(logger, method_name, event_dict): """Structlog processor to filter out base64 data from logs.""" def filter_value(value): if 
( isinstance(value, str) and len(value) > 100 and ("," in value or "/" in value) ): # Likely base64 data, truncate it return value[:20] + "..." elif isinstance(value, dict): return {k: filter_value(v) for k, v in value.items()} elif isinstance(value, list): return [filter_value(item) for item in value] return value return {k: filter_value(v) for k, v in event_dict.items()} def get_logger(name: str) -> structlog.BoundLogger: """Get a logger instance for a specific module. Args: name: Usually __name__ of the module Returns: A bound structured logger """ return structlog.get_logger(name) ================================================ FILE: studio/backend/main.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Main FastAPI application for Unsloth UI Backend """ import os # Suppress annoying C-level dependency warnings globally os.environ["PYTHONWARNINGS"] = "ignore" import shutil import sys import warnings from contextlib import asynccontextmanager # Suppress annoying dependency warnings in production if os.getenv("ENVIRONMENT_TYPE", "production") == "production": warnings.filterwarnings("ignore") # Alternatively, you can be more specific: # warnings.filterwarnings("ignore", category=DeprecationWarning) # warnings.filterwarnings("ignore", module="triton.*") from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from fastapi.responses import FileResponse, HTMLResponse, Response from pathlib import Path from datetime import datetime # Import routers from routes import ( auth_router, data_recipe_router, datasets_router, export_router, inference_router, models_router, training_router, ) from auth import storage from utils.hardware import detect_hardware, get_device, DeviceType import utils.hardware.hardware as _hw_module from utils.cache_cleanup import 
clear_unsloth_compiled_cache @asynccontextmanager async def lifespan(app: FastAPI): """Startup: detect hardware, seed default admin if needed. Shutdown: clean up compiled cache.""" # Clean up any stale compiled cache from previous runs clear_unsloth_compiled_cache() # Remove stale .venv_overlay from previous versions — no longer used. # Version switching now uses .venv_t5/ (pre-installed by setup.sh). overlay_dir = Path(__file__).resolve().parent.parent.parent / ".venv_overlay" if overlay_dir.is_dir(): shutil.rmtree(overlay_dir, ignore_errors = True) # Detect hardware first — sets DEVICE global used everywhere detect_hardware() # Pre-cache the helper GGUF model for LLM-assisted dataset detection. # Runs in a background thread so it doesn't block server startup. import threading def _precache(): try: from utils.datasets.llm_assist import precache_helper_gguf precache_helper_gguf() except Exception: pass # non-critical threading.Thread(target = _precache, daemon = True).start() if storage.ensure_default_admin(): bootstrap_pw = storage.get_bootstrap_password() app.state.bootstrap_password = bootstrap_pw print("\n" + "=" * 60) print("DEFAULT ADMIN ACCOUNT CREATED") print( "Sign in with the seeded credentials and change the password immediately:\n" ) print(f" username: {storage.DEFAULT_ADMIN_USERNAME}") print(f" password: {bootstrap_pw}\n") print("=" * 60 + "\n") else: app.state.bootstrap_password = storage.get_bootstrap_password() yield # Cleanup _hw_module.DEVICE = None clear_unsloth_compiled_cache() # Create FastAPI app app = FastAPI( title = "Unsloth UI Backend", version = "1.0.0", description = "Backend API for Unsloth UI - Training and Model Management", lifespan = lifespan, ) # Initialize structured logging from loggers.config import LogConfig from loggers.handlers import LoggingMiddleware logger = LogConfig.setup_logging( service_name = "unsloth-studio-backend", env = os.getenv("ENVIRONMENT_TYPE", "production"), ) app.add_middleware(LoggingMiddleware) # CORS 
middleware app.add_middleware( CORSMiddleware, allow_origins = ["*"], # In production, specify allowed origins allow_credentials = True, allow_methods = ["*"], allow_headers = ["*"], ) # ============ Register API Routes ============ # Register routers app.include_router(auth_router, prefix = "/api/auth", tags = ["auth"]) app.include_router(training_router, prefix = "/api/train", tags = ["training"]) app.include_router(models_router, prefix = "/api/models", tags = ["models"]) app.include_router(inference_router, prefix = "/api/inference", tags = ["inference"]) # OpenAI-compatible endpoints: mount the same inference router at /v1 # so external tools (Open WebUI, SillyTavern, etc.) can use the # standard /v1/chat/completions path. app.include_router(inference_router, prefix = "/v1", tags = ["openai-compat"]) app.include_router(datasets_router, prefix = "/api/datasets", tags = ["datasets"]) app.include_router(data_recipe_router, prefix = "/api/data-recipe", tags = ["data-recipe"]) app.include_router(export_router, prefix = "/api/export", tags = ["export"]) # ============ Health and System Endpoints ============ @app.get("/api/health") async def health_check(): """Health check endpoint""" platform_map = {"darwin": "mac", "win32": "windows", "linux": "linux"} device_type = platform_map.get(sys.platform, sys.platform) return { "status": "healthy", "timestamp": datetime.now().isoformat(), "service": "Unsloth UI Backend", "device_type": device_type, "chat_only": _hw_module.CHAT_ONLY, } @app.get("/api/system") async def get_system_info(): """Get system information""" import platform import subprocess import psutil from utils.hardware import get_device, get_gpu_memory_info, DeviceType # GPU Info — query nvidia-smi for physical GPUs, filtered by # CUDA_VISIBLE_DEVICES when set (the frontend uses this for GGUF # fit estimation and llama-server respects CVD too). 
import os gpu_info: dict = {"available": False, "devices": []} device = get_device() if device == DeviceType.CUDA: # Parse CUDA_VISIBLE_DEVICES allowlist allowed_indices = None cvd = os.environ.get("CUDA_VISIBLE_DEVICES") if cvd is not None and cvd.strip(): try: allowed_indices = set(int(x.strip()) for x in cvd.split(",")) except ValueError: pass # Non-numeric (e.g. GPU-uuid), show all try: result = subprocess.run( [ "nvidia-smi", "--query-gpu=index,name,memory.total", "--format=csv,noheader,nounits", ], capture_output = True, text = True, timeout = 10, ) if result.returncode == 0: for line in result.stdout.strip().splitlines(): parts = [p.strip() for p in line.split(",")] if len(parts) == 3: idx = int(parts[0]) if allowed_indices is not None and idx not in allowed_indices: continue gpu_info["devices"].append( { "index": idx, "name": parts[1], "memory_total_gb": round(int(parts[2]) / 1024, 2), } ) gpu_info["available"] = len(gpu_info["devices"]) > 0 except Exception: pass # Fallback to torch-based single-GPU detection if not gpu_info["available"]: mem_info = get_gpu_memory_info() if mem_info.get("available"): gpu_info["available"] = True gpu_info["devices"].append( { "index": mem_info.get("device", 0), "name": mem_info.get("device_name", "Unknown"), "memory_total_gb": round(mem_info.get("total_gb", 0), 2), } ) # CPU & Memory memory = psutil.virtual_memory() return { "platform": platform.platform(), "python_version": platform.python_version(), "device_backend": get_device().value, "cpu_count": psutil.cpu_count(), "memory": { "total_gb": round(memory.total / 1e9, 2), "available_gb": round(memory.available / 1e9, 2), "percent_used": memory.percent, }, "gpu": gpu_info, } @app.get("/api/system/hardware") async def get_hardware_info(): """Return GPU name, total VRAM, and key ML package versions.""" from utils.hardware import get_gpu_summary, get_package_versions return { "gpu": get_gpu_summary(), "versions": get_package_versions(), } # ============ Serve Frontend 
(Optional) ============ def _strip_crossorigin(html_bytes: bytes) -> bytes: """Remove ``crossorigin`` attributes from script/link tags. Vite adds ``crossorigin`` by default which forces CORS mode on font subresource loads. When Studio is served over plain HTTP, Firefox HTTPS-Only Mode does not exempt CORS font requests -- causing all @font-face downloads to fail silently. Stripping the attribute makes them regular same-origin fetches that work on any protocol. """ import re as _re html = html_bytes.decode("utf-8") html = _re.sub(r'\s+crossorigin(?:="[^"]*")?', "", html) return html.encode("utf-8") def _inject_bootstrap(html_bytes: bytes, app: FastAPI) -> bytes: """Inject bootstrap credentials into HTML when password change is required. The script tag is only injected while the default admin account still has ``must_change_password=True``. Once the user changes the password the HTML is served clean — no credentials leak. """ import json as _json if not storage.requires_password_change(storage.DEFAULT_ADMIN_USERNAME): return html_bytes bootstrap_pw = getattr(app.state, "bootstrap_password", None) if not bootstrap_pw: return html_bytes payload = _json.dumps( { "username": storage.DEFAULT_ADMIN_USERNAME, "password": bootstrap_pw, } ) tag = f"" html = html_bytes.decode("utf-8") html = html.replace("", f"{tag}", 1) return html.encode("utf-8") def setup_frontend(app: FastAPI, build_path: Path): """Mount frontend static files (optional)""" if not build_path.exists(): return False # Mount assets assets_dir = build_path / "assets" if assets_dir.exists(): app.mount("/assets", StaticFiles(directory = assets_dir), name = "assets") @app.get("/") async def serve_root(): content = (build_path / "index.html").read_bytes() content = _strip_crossorigin(content) content = _inject_bootstrap(content, app) return Response( content = content, media_type = "text/html", headers = {"Cache-Control": "no-cache, no-store, must-revalidate"}, ) @app.get("/{full_path:path}") async def 
serve_frontend(full_path: str): if full_path.startswith("api"): return {"error": "API endpoint not found"} file_path = (build_path / full_path).resolve() # Block path traversal — ensure resolved path stays inside build_path if not file_path.is_relative_to(build_path.resolve()): return Response(status_code = 403) if file_path.is_file(): return FileResponse(file_path) # Serve index.html as bytes — avoids Content-Length mismatch content = (build_path / "index.html").read_bytes() content = _strip_crossorigin(content) content = _inject_bootstrap(content, app) return Response( content = content, media_type = "text/html", headers = {"Cache-Control": "no-cache, no-store, must-revalidate"}, ) return True ================================================ FILE: studio/backend/models/.gitkeep ================================================ ================================================ FILE: studio/backend/models/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0

"""
Pydantic models for API request/response schemas
"""

# Re-export the request/response schemas from the submodules so callers can
# import them from this package directly.
from .training import (
    TrainingStartRequest,
    TrainingJobResponse,
    TrainingStatus,
    TrainingProgress,
)
from .models import (
    CheckpointInfo,
    ModelCheckpoints,
    CheckpointListResponse,
    ModelDetails,
    LocalModelInfo,
    LocalModelListResponse,
    LoRAInfo,
    LoRAScanResponse,
    ModelListResponse,
)
from .auth import (
    AuthLoginRequest,
    RefreshTokenRequest,
    AuthStatusResponse,
    ChangePasswordRequest,
)
from .export import (
    LoadCheckpointRequest,
    ExportStatusResponse,
    ExportOperationResponse,
    ExportMergedModelRequest,
    ExportBaseModelRequest,
    ExportGGUFRequest,
    ExportLoRAAdapterRequest,
)
from .users import Token
from .datasets import (
    CheckFormatRequest,
    CheckFormatResponse,
)
from .inference import (
    LoadRequest,
    UnloadRequest,
    GenerateRequest,
    LoadResponse,
    UnloadResponse,
    InferenceStatusResponse,
)
from .responses import (
    TrainingStopResponse,
    TrainingMetricsResponse,
    LoRABaseModelResponse,
    VisionCheckResponse,
    EmbeddingCheckResponse,
)
# NOTE(review): only the core recipe schemas are re-exported; the
# PublishDataset*/SeedInspect* schemas in data_recipe.py are not — presumably
# callers import those from .data_recipe directly; confirm.
from .data_recipe import (
    RecipePayload,
    PreviewResponse,
    ValidateError,
    ValidateResponse,
    JobCreateResponse,
)

__all__ = [
    # Training schemas
    "TrainingStartRequest",
    "TrainingJobResponse",
    "TrainingStatus",
    "TrainingProgress",
    # Model management schemas
    "ModelDetails",
    "LocalModelInfo",
    "LocalModelListResponse",
    "LoRAInfo",
    "LoRAScanResponse",
    "ModelListResponse",
    # Auth schemas
    "AuthLoginRequest",
    "RefreshTokenRequest",
    "AuthStatusResponse",
    "ChangePasswordRequest",
    # Export schemas (the three checkpoint schemas are imported from .models)
    "CheckpointInfo",
    "ModelCheckpoints",
    "CheckpointListResponse",
    "LoadCheckpointRequest",
    "ExportStatusResponse",
    "ExportOperationResponse",
    "ExportMergedModelRequest",
    "ExportBaseModelRequest",
    "ExportGGUFRequest",
    "ExportLoRAAdapterRequest",
    "Token",
    # Dataset schemas
    "CheckFormatRequest",
    "CheckFormatResponse",
    # Inference schemas
    "LoadRequest",
    "UnloadRequest",
    "GenerateRequest",
    "LoadResponse",
    "UnloadResponse",
    "InferenceStatusResponse",
    # Response schemas
    "TrainingStopResponse",
    "TrainingMetricsResponse",
    "LoRABaseModelResponse",
    "VisionCheckResponse",
    "EmbeddingCheckResponse",
    # Data recipe
    "RecipePayload",
    "PreviewResponse",
    "ValidateError",
    "ValidateResponse",
    "JobCreateResponse",
]


================================================
FILE: studio/backend/models/auth.py
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

"""
Pydantic schemas for Authentication API
"""

from pydantic import BaseModel, Field


class AuthLoginRequest(BaseModel):
    """Login payload: username/password to obtain a JWT."""

    username: str = Field(..., description = "Username")
    password: str = Field(..., description = "Password")


class RefreshTokenRequest(BaseModel):
    """Refresh token payload to obtain new access + refresh tokens."""

    refresh_token: str = Field(
        ..., description = "Refresh token from a previous login or refresh"
    )


class AuthStatusResponse(BaseModel):
    """Indicate whether the seeded admin auth flow is ready."""

    initialized: bool = Field(
        ..., description = "True if the auth database contains a login user"
    )
    default_username: str = Field(..., description = "Default seeded admin username")
    requires_password_change: bool = Field(
        ...,
        description = "True if the seeded admin must still change the default password",
    )


class ChangePasswordRequest(BaseModel):
    """Change the current user's password, typically on first login."""

    # Both fields are length-validated (>= 8 chars), so an existing password
    # shorter than 8 characters would be rejected at request validation.
    current_password: str = Field(
        ..., min_length = 8, description = "Existing password for the authenticated user"
    )
    new_password: str = Field(
        ..., min_length = 8, description = "Replacement password (minimum 8 characters)"
    )


================================================
FILE: studio/backend/models/data_recipe.py
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

"""
Pydantic schemas for Data Recipe (DataDesigner) API.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field

# Class-level "#" comments are used below instead of docstrings because
# pydantic copies class docstrings into the generated JSON/OpenAPI schema.


# Recipe definition plus optional `run` / `ui` sections; all three are
# free-form dicts (their inner structure is not validated here).
class RecipePayload(BaseModel):
    recipe: dict[str, Any] = Field(default_factory = dict)
    run: dict[str, Any] | None = None
    ui: dict[str, Any] | None = None


# Preview output: rows in `dataset`, plus optional artifact/analysis blobs.
class PreviewResponse(BaseModel):
    dataset: list[dict[str, Any]] = Field(default_factory = list)
    processor_artifacts: dict[str, Any] | None = None
    analysis: dict[str, Any] | None = None


# One validation problem; `path` and `code` are optional locators.
class ValidateError(BaseModel):
    message: str
    path: str | None = None
    code: str | None = None


# NOTE(review): `raw_detail` presumably carries the unparsed validator output
# when it could not be mapped into `errors` — confirm against the route.
class ValidateResponse(BaseModel):
    valid: bool
    errors: list[ValidateError] = Field(default_factory = list)
    raw_detail: str | None = None


# Carries only the id of the newly created job.
class JobCreateResponse(BaseModel):
    job_id: str


# Publish request: repo_id needs >= 3 chars, description 1..4000 chars.
class PublishDatasetRequest(BaseModel):
    repo_id: str = Field(min_length = 3, description = "Hugging Face dataset repo ID")
    description: str = Field(
        min_length = 1,
        max_length = 4000,
        description = "Short dataset description for the dataset card",
    )
    hf_token: str | None = Field(
        default = None,
        description = "Optional Hugging Face token for private or write-protected repos",
    )
    private: bool = Field(
        default = False,
        description = "Create or update the dataset repo as private",
    )
    artifact_path: str | None = Field(
        default = None,
        description = "Execution artifact path captured by the UI for completed runs",
    )


# `success` defaults to True; `url` and `message` are required.
class PublishDatasetResponse(BaseModel):
    success: bool = True
    url: str
    message: str


# `split` defaults to "train"; `preview_size` is bounded to 1..50 (default 10).
class SeedInspectRequest(BaseModel):
    dataset_name: str = Field(min_length = 1)
    hf_token: str | None = None
    subset: str | None = None
    split: str | None = "train"
    preview_size: int = Field(default = 10, ge = 1, le = 50)


class SeedInspectUploadRequest(BaseModel): filename: str = Field(min_length = 1) content_base64: str = Field(min_length = 1) preview_size: int = Field(default = 10, ge = 1, le = 50) seed_source_type: str | None = None unstructured_chunk_size: int | None = Field(default = None,
ge = 1, le = 20000) unstructured_chunk_overlap: int | None = Field(default = None, ge = 0, le = 20000) class SeedInspectResponse(BaseModel): dataset_name: str resolved_path: str columns: list[str] = Field(default_factory = list) preview_rows: list[dict[str, Any]] = Field(default_factory = list) split: str | None = None subset: str | None = None class McpToolsListRequest(BaseModel): mcp_providers: list[dict[str, Any]] = Field(default_factory = list) timeout_sec: float | None = Field(default = None, gt = 0) class McpToolsProviderResult(BaseModel): name: str tools: list[str] = Field(default_factory = list) error: str | None = None class McpToolsListResponse(BaseModel): providers: list[McpToolsProviderResult] = Field(default_factory = list) duplicate_tools: dict[str, list[str]] = Field(default_factory = dict) ================================================ FILE: studio/backend/models/datasets.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Dataset-related Pydantic models for API requests and responses. 
""" from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field, model_validator class CheckFormatRequest(BaseModel): """Request for dataset format check""" dataset_name: str # HuggingFace dataset name or local path is_vlm: bool = False hf_token: Optional[str] = None subset: Optional[str] = None train_split: Optional[str] = "train" @model_validator(mode = "before") @classmethod def _compat_split(cls, values: Any) -> Any: """Accept legacy 'split' field as alias for 'train_split'.""" if isinstance(values, dict) and "split" in values: values.setdefault("train_split", values.pop("split")) return values class CheckFormatResponse(BaseModel): """Response for dataset format check""" requires_manual_mapping: bool detected_format: str columns: List[str] is_image: bool = False is_audio: bool = False multimodal_columns: Optional[List[str]] = None suggested_mapping: Optional[Dict[str, str]] = None detected_image_column: Optional[str] = None detected_audio_column: Optional[str] = None detected_text_column: Optional[str] = None detected_speaker_column: Optional[str] = None preview_samples: Optional[List[Dict]] = None total_rows: Optional[int] = None warning: Optional[str] = None class AiAssistMappingRequest(BaseModel): """Request for LLM-assisted column classification (user-triggered).""" columns: List[str] samples: List[Dict[str, Any]] # Preview rows already loaded in the dialog dataset_name: Optional[str] = None # For LLM context hf_token: Optional[str] = None # For fetching dataset card model_name: Optional[str] = None model_type: Optional[str] = None class AiAssistMappingResponse(BaseModel): """Response from LLM-assisted column classification and conversion advice.""" success: bool suggested_mapping: Optional[Dict[str, str]] = None warning: Optional[str] = None # Conversion advisor fields system_prompt: Optional[str] = None label_mapping: Optional[Dict[str, Dict[str, str]]] = None dataset_type: Optional[str] = None is_conversational: Optional[bool] = 
None user_notification: Optional[str] = None class UploadDatasetResponse(BaseModel): """Response with stored dataset path for training.""" filename: str = Field(..., description = "Original filename") stored_path: str = Field(..., description = "Absolute path stored on backend") class LocalDatasetItem(BaseModel): class Metadata(BaseModel): actual_num_records: Optional[int] = None target_num_records: Optional[int] = None total_num_batches: Optional[int] = None num_completed_batches: Optional[int] = None columns: Optional[List[str]] = None id: str label: str path: str rows: Optional[int] = None updated_at: Optional[float] = None metadata: Optional[Metadata] = None class LocalDatasetsResponse(BaseModel): datasets: List[LocalDatasetItem] = Field(default_factory = list) ================================================ FILE: studio/backend/models/export.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Pydantic schemas for Export API. """ from pydantic import BaseModel, Field from typing import List, Optional, Literal, Dict, Any class LoadCheckpointRequest(BaseModel): """Request for loading a checkpoint into the export backend.""" checkpoint_path: str = Field(..., description = "Path to the checkpoint directory") max_seq_length: int = Field( 2048, ge = 128, le = 32768, description = "Maximum sequence length for loading the model", ) load_in_4bit: bool = Field( True, description = "Whether to load the model in 4-bit quantization", ) trust_remote_code: bool = Field( False, description = "Allow loading models with custom code. 
Only enable for checkpoints/base models you trust.", ) class ExportStatusResponse(BaseModel): """Current export backend status.""" current_checkpoint: Optional[str] = Field( None, description = "Path to the currently loaded checkpoint, if any", ) is_vision: bool = Field( False, description = "True if the loaded checkpoint is a vision model", ) is_peft: bool = Field( False, description = "True if the loaded checkpoint is a PEFT (LoRA) model", ) class ExportOperationResponse(BaseModel): """Generic response for export operations.""" success: bool = Field(..., description = "True if the operation succeeded") message: str = Field(..., description = "Human-readable status or error message") details: Optional[Dict[str, Any]] = Field( default = None, description = "Optional extra details about the operation", ) class ExportCommonOptions(BaseModel): """Common options for export operations that save locally and/or push to Hub.""" save_directory: str = Field( ..., description = "Local directory where the exported artifacts will be written", ) push_to_hub: bool = Field( False, description = "If True, also push the exported model to the Hugging Face Hub", ) repo_id: Optional[str] = Field( None, description = "Hugging Face Hub repository ID (username/model-name)", ) hf_token: Optional[str] = Field( None, description = "Hugging Face access token used for Hub operations", ) private: bool = Field( False, description = "If True, create a private repository on the Hub (where applicable)", ) base_model_id: Optional[str] = Field( None, description = "HuggingFace model ID of the base model (for model card metadata)", ) class ExportMergedModelRequest(ExportCommonOptions): """Request for exporting a merged PEFT model.""" format_type: Literal["16-bit (FP16)", "4-bit (FP4)"] = Field( "16-bit (FP16)", description = "Export precision / format for the merged model", ) class ExportBaseModelRequest(ExportCommonOptions): """Request for exporting a non-PEFT (base) model.""" # Uses fields from 
ExportCommonOptions only class ExportGGUFRequest(BaseModel): """Request for exporting the current model to GGUF format.""" save_directory: str = Field( ..., description = "Directory where GGUF files will be saved", ) quantization_method: str = Field( "Q4_K_M", description = 'GGUF quantization method (e.g. "Q4_K_M")', ) push_to_hub: bool = Field( False, description = "If True, also push GGUF artifacts to the Hugging Face Hub", ) repo_id: Optional[str] = Field( None, description = "Hugging Face Hub repository ID for GGUF upload", ) hf_token: Optional[str] = Field( None, description = "Hugging Face token for GGUF upload", ) class ExportLoRAAdapterRequest(ExportCommonOptions): """Request for exporting only the LoRA adapter (not merged).""" # Uses fields from ExportCommonOptions only ================================================ FILE: studio/backend/models/inference.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Pydantic schemas for Inference API """ from __future__ import annotations import time import uuid from typing import Annotated, Any, Dict, Literal, Optional, List, Union from pydantic import BaseModel, Discriminator, Field, Tag class LoadRequest(BaseModel): """Request to load a model for inference""" model_path: str = Field(..., description = "Model identifier or local path") hf_token: Optional[str] = Field( None, description = "HuggingFace token for gated models" ) max_seq_length: int = Field( 4096, ge = 128, le = 32768, description = "Maximum sequence length" ) load_in_4bit: bool = Field(True, description = "Load model in 4-bit quantization") is_lora: bool = Field(False, description = "Whether this is a LoRA adapter") gguf_variant: Optional[str] = Field( None, description = "GGUF quantization variant (e.g. 
'Q4_K_M')" ) trust_remote_code: bool = Field( False, description = "Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust.", ) chat_template_override: Optional[str] = Field( None, description = "Custom Jinja2 chat template to use instead of the model's default", ) cache_type_kv: Optional[str] = Field( None, description = "KV cache data type for both K and V (e.g. 'f16', 'bf16', 'q8_0', 'q4_1', 'q5_1')", ) class UnloadRequest(BaseModel): """Request to unload a model""" model_path: str = Field(..., description = "Model identifier to unload") class ValidateModelRequest(BaseModel): """ Lightweight validation request to check whether a model identifier *can be resolved* into a ModelConfig. This does NOT actually load weights into GPU memory. """ model_path: str = Field(..., description = "Model identifier or local path") hf_token: Optional[str] = Field( None, description = "HuggingFace token for gated models" ) gguf_variant: Optional[str] = Field( None, description = "GGUF quantization variant (e.g. 'Q4_K_M')" ) class ValidateModelResponse(BaseModel): """ Result of model validation. valid == True means ModelConfig.from_identifier() succeeded and basic introspection (GGUF / LoRA / vision flags) is available. 
""" valid: bool = Field(..., description = "Whether the model identifier looks valid") message: str = Field(..., description = "Human-readable validation message") identifier: Optional[str] = Field(None, description = "Resolved model identifier") display_name: Optional[str] = Field( None, description = "Display name derived from identifier" ) is_gguf: bool = Field(False, description = "Whether this is a GGUF model (llama.cpp)") is_lora: bool = Field(False, description = "Whether this is a LoRA adapter") is_vision: bool = Field(False, description = "Whether this is a vision-capable model") class GenerateRequest(BaseModel): """Request for text generation (legacy /generate/stream endpoint)""" messages: List[dict] = Field(..., description = "Chat messages in OpenAI format") system_prompt: str = Field("", description = "System prompt") temperature: float = Field(0.6, ge = 0.0, le = 2.0, description = "Sampling temperature") top_p: float = Field(0.95, ge = 0.0, le = 1.0, description = "Top-p sampling") top_k: int = Field(20, ge = -1, le = 100, description = "Top-k sampling") max_new_tokens: int = Field( 2048, ge = 1, le = 4096, description = "Maximum tokens to generate" ) repetition_penalty: float = Field( 1.0, ge = 1.0, le = 2.0, description = "Repetition penalty" ) presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty") image_base64: Optional[str] = Field( None, description = "Base64 encoded image for vision models" ) class LoadResponse(BaseModel): """Response after loading a model""" status: str = Field(..., description = "Load status") model: str = Field(..., description = "Model identifier") display_name: str = Field(..., description = "Display name of the model") is_vision: bool = Field(False, description = "Whether model is a vision model") is_lora: bool = Field(False, description = "Whether model is a LoRA adapter") is_gguf: bool = Field( False, description = "Whether model is a GGUF model (llama.cpp)" ) is_audio: bool = 
Field(False, description = "Whether model is a TTS audio model") audio_type: Optional[str] = Field( None, description = "Audio codec type: snac, csm, bicodec, dac" ) has_audio_input: bool = Field( False, description = "Whether model accepts audio input (ASR)" ) inference: dict = Field( ..., description = "Inference parameters (temperature, top_p, top_k, min_p)" ) context_length: Optional[int] = Field( None, description = "Model's native context length (from GGUF metadata)" ) supports_reasoning: bool = Field( False, description = "Whether model supports thinking/reasoning mode (enable_thinking)", ) supports_tools: bool = Field( False, description = "Whether model supports tool calling (web search, etc.)", ) cache_type_kv: Optional[str] = Field( None, description = "KV cache data type for K and V (e.g. 'f16', 'bf16', 'q8_0')", ) chat_template: Optional[str] = Field( None, description = "Jinja2 chat template string (from GGUF metadata or tokenizer)", ) class UnloadResponse(BaseModel): """Response after unloading a model""" status: str = Field(..., description = "Unload status") model: str = Field(..., description = "Model identifier that was unloaded") class InferenceStatusResponse(BaseModel): """Current inference backend status""" active_model: Optional[str] = Field( None, description = "Currently active model identifier" ) is_vision: bool = Field( False, description = "Whether the active model is a vision model" ) is_gguf: bool = Field( False, description = "Whether the active model is a GGUF model (llama.cpp)" ) gguf_variant: Optional[str] = Field( None, description = "GGUF quantization variant (e.g. 
Q4_K_M)" ) is_audio: bool = Field( False, description = "Whether the active model is a TTS audio model" ) audio_type: Optional[str] = Field( None, description = "Audio codec type: snac, csm, bicodec, dac" ) has_audio_input: bool = Field( False, description = "Whether model accepts audio input (ASR)" ) loading: List[str] = Field( default_factory = list, description = "Models currently being loaded" ) loaded: List[str] = Field( default_factory = list, description = "Models currently loaded" ) inference: Optional[Dict[str, Any]] = Field( None, description = "Recommended inference parameters for the active model" ) supports_reasoning: bool = Field( False, description = "Whether the active model supports reasoning/thinking mode" ) supports_tools: bool = Field( False, description = "Whether the active model supports tool calling" ) context_length: Optional[int] = Field( None, description = "Context length of the active model" ) # ===================================================================== # OpenAI-Compatible Chat Completions Models # ===================================================================== # ── Multimodal content parts (OpenAI vision format) ────────────── class TextContentPart(BaseModel): """Text content part in a multimodal message.""" type: Literal["text"] text: str class ImageUrl(BaseModel): """Image URL object — supports data URIs and remote URLs.""" url: str = Field(..., description = "data:image/png;base64,... 
or https://...") detail: Optional[Literal["auto", "low", "high"]] = "auto" class ImageContentPart(BaseModel): """Image content part in a multimodal message.""" type: Literal["image_url"] image_url: ImageUrl def _content_part_discriminator(v): if isinstance(v, dict): return v.get("type") return getattr(v, "type", None) ContentPart = Annotated[ Union[ Annotated[TextContentPart, Tag("text")], Annotated[ImageContentPart, Tag("image_url")], ], Discriminator(_content_part_discriminator), ] """Union type for multimodal content parts, discriminated by the 'type' field.""" # ── Messages ───────────────────────────────────────────────────── class ChatMessage(BaseModel): """ A single message in the conversation. ``content`` may be a plain string (text-only) or a list of content parts for multimodal messages (OpenAI vision format). """ role: Literal["system", "user", "assistant"] = Field( ..., description = "Message role" ) content: Union[str, list[ContentPart]] = Field( ..., description = "Message content (string or multimodal parts)" ) class ChatCompletionRequest(BaseModel): """ OpenAI-compatible chat completion request. Extensions (non-OpenAI fields) are marked with 'x-unsloth'. 
""" model: str = Field( "default", description = "Model identifier (informational; the active model is used)", ) messages: list[ChatMessage] = Field(..., description = "Conversation messages") stream: bool = Field(True, description = "Whether to stream the response via SSE") temperature: float = Field(0.6, ge = 0.0, le = 2.0) top_p: float = Field(0.95, ge = 0.0, le = 1.0) max_tokens: Optional[int] = Field( None, ge = 1, description = "Maximum tokens to generate (None = until EOS)" ) presence_penalty: float = Field(0.0, ge = 0.0, le = 2.0, description = "Presence penalty") # ── Unsloth extensions (ignored by standard OpenAI clients) ── top_k: int = Field(20, ge = -1, le = 100, description = "[x-unsloth] Top-k sampling") min_p: float = Field( 0.01, ge = 0.0, le = 1.0, description = "[x-unsloth] Min-p sampling threshold" ) repetition_penalty: float = Field( 1.1, ge = 1.0, le = 2.0, description = "[x-unsloth] Repetition penalty" ) image_base64: Optional[str] = Field( None, description = "[x-unsloth] Base64-encoded image for vision models" ) audio_base64: Optional[str] = Field( None, description = "[x-unsloth] Base64-encoded WAV for audio-input models (ASR)" ) use_adapter: Optional[Union[bool, str]] = Field( None, description = ( "[x-unsloth] Adapter control for compare mode. " "null = no change (default), " "false = disable adapters (base model), " "true = enable the current adapter, " "string = enable a specific adapter by name." ), ) enable_thinking: Optional[bool] = Field( None, description = "[x-unsloth] Enable/disable thinking/reasoning mode for supported models", ) enable_tools: Optional[bool] = Field( None, description = "[x-unsloth] Enable tool calling for supported models", ) enabled_tools: Optional[list[str]] = Field( None, description = "[x-unsloth] List of enabled tool names (e.g. ['web_search', 'python', 'terminal']). 
If None, all tools are enabled.", ) auto_heal_tool_calls: Optional[bool] = Field( True, description = "[x-unsloth] Auto-detect and fix malformed tool calls from model output.", ) max_tool_calls_per_message: Optional[int] = Field( 10, ge = 0, description = "[x-unsloth] Maximum number of tool call iterations per message (0 = disabled, 9999 = unlimited).", ) tool_call_timeout: Optional[int] = Field( 300, ge = 1, description = "[x-unsloth] Timeout in seconds for each tool call execution (9999 = no limit).", ) session_id: Optional[str] = Field( None, description = "[x-unsloth] Session/thread ID for scoping tool execution sandbox.", ) # ── Streaming response chunks ──────────────────────────────────── class ChoiceDelta(BaseModel): """Delta content for a streaming chunk.""" role: Optional[str] = None content: Optional[str] = None class ChunkChoice(BaseModel): """A single choice in a streaming chunk.""" index: int = 0 delta: ChoiceDelta finish_reason: Optional[Literal["stop", "length"]] = None class ChatCompletionChunk(BaseModel): """A single SSE chunk in OpenAI streaming format.""" id: str = Field(default_factory = lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}") object: Literal["chat.completion.chunk"] = "chat.completion.chunk" created: int = Field(default_factory = lambda: int(time.time())) model: str = "default" choices: list[ChunkChoice] # ── Non-streaming response ─────────────────────────────────────── class CompletionMessage(BaseModel): """The assistant's complete response message.""" role: Literal["assistant"] = "assistant" content: str class CompletionChoice(BaseModel): """A single choice in a non-streaming response.""" index: int = 0 message: CompletionMessage finish_reason: Literal["stop", "length"] = "stop" class CompletionUsage(BaseModel): """Token usage statistics (approximate).""" prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 class ChatCompletion(BaseModel): """Non-streaming chat completion response.""" id: str = Field(default_factory 
= lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}") object: Literal["chat.completion"] = "chat.completion" created: int = Field(default_factory = lambda: int(time.time())) model: str = "default" choices: list[CompletionChoice] usage: CompletionUsage = Field(default_factory = CompletionUsage) ================================================ FILE: studio/backend/models/models.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Pydantic schemas for Model Management API """ from pydantic import BaseModel, Field from typing import Optional, List, Dict, Any, Literal ModelType = Literal["text", "vision", "audio", "embeddings"] class CheckpointInfo(BaseModel): """Information about a discovered checkpoint directory.""" display_name: str = Field( ..., description = "User-friendly checkpoint name (folder name)" ) path: str = Field(..., description = "Full path to the checkpoint directory") loss: Optional[float] = Field(None, description = "Training loss at this checkpoint") class ModelCheckpoints(BaseModel): """A training run and its associated checkpoints.""" name: str = Field(..., description = "Training run folder name") checkpoints: List[CheckpointInfo] = Field( default_factory = list, description = "List of checkpoints for this training run (final + intermediate)", ) base_model: Optional[str] = Field( None, description = "Base model name from adapter_config.json or config.json", ) peft_type: Optional[str] = Field( None, description = "PEFT type (e.g. 
LORA) if adapter training, None for full fine-tune", ) lora_rank: Optional[int] = Field( None, description = "LoRA rank (r) if applicable", ) class CheckpointListResponse(BaseModel): """Response for listing available checkpoints in an outputs directory.""" outputs_dir: str = Field(..., description = "Directory that was scanned") models: List[ModelCheckpoints] = Field( default_factory = list, description = "List of training runs with their checkpoints", ) class ModelDetails(BaseModel): """Detailed model configuration and metadata - can be used for both list and detail views""" id: str = Field(..., description = "Model identifier") model_name: Optional[str] = Field( None, description = "Model identifier (alias for id, for backward compatibility)" ) name: Optional[str] = Field(None, description = "Display name for the model") config: Optional[Dict[str, Any]] = Field( None, description = "Model configuration dictionary" ) is_vision: bool = Field(False, description = "Whether model is a vision model") is_embedding: bool = Field( False, description = "Whether model is an embedding/sentence-transformer model" ) is_lora: bool = Field(False, description = "Whether model is a LoRA adapter") is_gguf: bool = Field( False, description = "Whether model is a GGUF model (llama.cpp format)" ) is_audio: bool = Field(False, description = "Whether model is a TTS audio model") audio_type: Optional[str] = Field( None, description = "Audio codec type: snac, csm, bicodec, dac" ) has_audio_input: bool = Field( False, description = "Whether model accepts audio input (ASR)" ) model_type: Optional[ModelType] = Field( None, description = "Collapsed model modality: text, vision, audio, or embeddings" ) base_model: Optional[str] = Field( None, description = "Base model if this is a LoRA adapter" ) max_position_embeddings: Optional[int] = Field( None, description = "Maximum context length supported by the model" ) model_size_bytes: Optional[int] = Field( None, description = "Total size of model 
weight files in bytes" ) class LoRAInfo(BaseModel): """LoRA adapter or exported model information""" display_name: str = Field(..., description = "Display name for the LoRA") adapter_path: str = Field( ..., description = "Path to the LoRA adapter or exported model" ) base_model: Optional[str] = Field(None, description = "Base model identifier") source: Optional[str] = Field(None, description = "'training' or 'exported'") export_type: Optional[str] = Field( None, description = "'lora', 'merged', or 'gguf' (for exports)" ) class LoRAScanResponse(BaseModel): """Response schema for scanning trained LoRA adapters""" loras: List[LoRAInfo] = Field( default_factory = list, description = "List of found LoRA adapters" ) outputs_dir: str = Field(..., description = "Directory that was scanned") class ModelListResponse(BaseModel): """Response schema for listing models""" models: List[ModelDetails] = Field( default_factory = list, description = "List of models" ) default_models: List[str] = Field( default_factory = list, description = "List of default model IDs" ) class GgufVariantDetail(BaseModel): """A single GGUF quantization variant in a HuggingFace repo.""" filename: str = Field( ..., description = "GGUF filename (e.g., 'gemma-3-4b-it-Q4_K_M.gguf')" ) quant: str = Field(..., description = "Quantization label (e.g., 'Q4_K_M')") size_bytes: int = Field(0, description = "File size in bytes") downloaded: bool = Field( False, description = "Whether this variant is already in the local HF cache" ) class GgufVariantsResponse(BaseModel): """Response for listing GGUF quantization variants in a HuggingFace repo.""" repo_id: str = Field(..., description = "HuggingFace repo ID") variants: List[GgufVariantDetail] = Field( default_factory = list, description = "Available GGUF variants" ) has_vision: bool = Field( False, description = "Whether the model has vision support (mmproj files)" ) default_variant: Optional[str] = Field( None, description = "Recommended default quantization 
variant" ) class LocalModelInfo(BaseModel): """Discovered local model candidate.""" id: str = Field(..., description = "Identifier to use for loading/training") display_name: str = Field(..., description = "Display label") path: str = Field(..., description = "Local path where model data was discovered") source: Literal["models_dir", "hf_cache"] = Field( ..., description = "Discovery source", ) model_id: Optional[str] = Field( None, description = "HF repo id for cached models, e.g. org/model", ) updated_at: Optional[float] = Field( None, description = "Unix timestamp of latest observed update", ) class LocalModelListResponse(BaseModel): """Response schema for listing local/cached models.""" models_dir: str = Field( ..., description = "Directory scanned for custom local models" ) hf_cache_dir: Optional[str] = Field( None, description = "HF cache root that was scanned", ) models: List[LocalModelInfo] = Field( default_factory = list, description = "Discovered local/cached models", ) ================================================ FILE: studio/backend/models/responses.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Pydantic response schemas for endpoints that previously returned raw dicts. These are small response models for training and model management routes. 
""" from pydantic import BaseModel, Field from typing import Optional, List # --- Training route response models --- class TrainingStopResponse(BaseModel): """Response for stopping a training job""" status: str = Field(..., description = "Current status: 'stopped' or 'idle'") message: str = Field(..., description = "Human-readable status message") class TrainingMetricsResponse(BaseModel): """Response for training metrics history""" loss_history: List[float] = Field( default_factory = list, description = "Loss values per step" ) lr_history: List[float] = Field( default_factory = list, description = "Learning rate per step" ) step_history: List[int] = Field(default_factory = list, description = "Step numbers") grad_norm_history: List[float] = Field( default_factory = list, description = "Gradient norm values" ) grad_norm_step_history: List[int] = Field( default_factory = list, description = "Step numbers for gradient norm values" ) current_loss: Optional[float] = Field(None, description = "Most recent loss value") current_lr: Optional[float] = Field(None, description = "Most recent learning rate") current_step: Optional[int] = Field(None, description = "Most recent step number") # --- Model management route response models --- class LoRABaseModelResponse(BaseModel): """Response for getting a LoRA's base model""" lora_path: str = Field(..., description = "Path to the LoRA adapter") base_model: str = Field(..., description = "Base model identifier") class VisionCheckResponse(BaseModel): """Response for checking if a model is a vision model""" model_name: str = Field(..., description = "Model identifier") is_vision: bool = Field(..., description = "Whether the model is a vision model") class EmbeddingCheckResponse(BaseModel): """Response for checking if a model is an embedding model""" model_name: str = Field(..., description = "Model identifier") is_embedding: bool = Field( ..., description = "Whether the model is an embedding/sentence-transformer model" ) 
================================================ FILE: studio/backend/models/training.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Pydantic schemas for Training API """ from pydantic import BaseModel, Field, model_validator from typing import Any, Optional, List, Dict, Literal class TrainingStartRequest(BaseModel): """Request schema for starting training""" # Model parameters model_name: str = Field( ..., description = "Model identifier (e.g., 'unsloth/llama-3-8b-bnb-4bit')" ) training_type: str = Field( ..., description = "Training type: 'LoRA/QLoRA' or 'Full Finetuning'" ) hf_token: Optional[str] = Field(None, description = "HuggingFace token") load_in_4bit: bool = Field(True, description = "Load model in 4-bit quantization") max_seq_length: int = Field(2048, description = "Maximum sequence length") trust_remote_code: bool = Field( False, description = "Allow loading models with custom code (e.g. NVIDIA Nemotron). Only enable for repos you trust.", ) # Dataset parameters hf_dataset: Optional[str] = Field( None, description = "HuggingFace dataset identifier" ) local_datasets: List[str] = Field( default_factory = list, description = "List of local dataset paths" ) local_eval_datasets: List[str] = Field( default_factory = list, description = "List of local eval dataset paths" ) format_type: str = Field(..., description = "Dataset format type") subset: Optional[str] = None train_split: Optional[str] = Field("train", description = "Training split name") eval_split: Optional[str] = Field( None, description = "Eval split name. 
None = auto-detect" ) eval_steps: float = Field( 0.00, description = "Fraction of total steps between evals (0-1)" ) dataset_slice_start: Optional[int] = Field( None, description = "Inclusive start row index for dataset slicing" ) dataset_slice_end: Optional[int] = Field( None, description = "Inclusive end row index for dataset slicing" ) @model_validator(mode = "before") @classmethod def _compat_split(cls, values: Any) -> Any: """Accept legacy 'split' field as alias for 'train_split'.""" if isinstance(values, dict) and "split" in values: values.setdefault("train_split", values.pop("split")) return values custom_format_mapping: Optional[Dict[str, Any]] = Field( None, description = ( "User-provided column-to-role mapping, e.g. {'image': 'image', 'caption': 'text'} " "for VLM or {'instruction': 'user', 'output': 'assistant'} for LLM. " "Enhanced format includes __system_prompt, __user_template, " "__assistant_template, __label_mapping metadata keys." ), ) # Training parameters num_epochs: int = Field(1, description = "Number of training epochs") learning_rate: str = Field("2e-4", description = "Learning rate") batch_size: int = Field(1, description = "Batch size") gradient_accumulation_steps: int = Field( 1, description = "Gradient accumulation steps" ) warmup_steps: Optional[int] = Field(None, description = "Warmup steps") warmup_ratio: Optional[float] = Field(None, description = "Warmup ratio") max_steps: Optional[int] = Field(None, description = "Maximum training steps") save_steps: int = Field(100, description = "Steps between checkpoints") weight_decay: float = Field(0.01, description = "Weight decay") random_seed: int = Field(42, description = "Random seed") packing: bool = Field(False, description = "Enable sequence packing") optim: str = Field("adamw_8bit", description = "Optimizer") lr_scheduler_type: str = Field("linear", description = "Learning rate scheduler type") # LoRA parameters use_lora: bool = Field(True, description = "Use LoRA (derived from 
training_type)") lora_r: int = Field(16, description = "LoRA rank") lora_alpha: int = Field(16, description = "LoRA alpha") lora_dropout: float = Field(0.0, description = "LoRA dropout") target_modules: List[str] = Field( default_factory = list, description = "Target modules for LoRA" ) gradient_checkpointing: str = Field( "", description = "Gradient checkpointing setting" ) use_rslora: bool = Field(False, description = "Use RSLoRA") use_loftq: bool = Field(False, description = "Use LoftQ") train_on_completions: bool = Field(False, description = "Train on completions only") # Vision-specific LoRA parameters finetune_vision_layers: bool = Field(False, description = "Finetune vision layers") finetune_language_layers: bool = Field( False, description = "Finetune language layers" ) finetune_attention_modules: bool = Field( False, description = "Finetune attention modules" ) finetune_mlp_modules: bool = Field(False, description = "Finetune MLP modules") is_dataset_image: bool = Field( False, description = "Whether the dataset contains image data" ) is_dataset_audio: bool = Field( False, description = "Whether the dataset contains audio data" ) is_embedding: bool = Field( False, description = "Whether model is an embedding/sentence-transformer model" ) # Logging parameters enable_wandb: bool = Field(False, description = "Enable Weights & Biases logging") wandb_token: Optional[str] = Field(None, description = "W&B token") wandb_project: Optional[str] = Field(None, description = "W&B project name") enable_tensorboard: bool = Field(False, description = "Enable TensorBoard logging") tensorboard_dir: Optional[str] = Field(None, description = "TensorBoard directory") class TrainingJobResponse(BaseModel): """Immediate response when training is initiated""" job_id: str = Field(..., description = "Unique training job identifier") status: Literal["queued", "error"] = Field(..., description = "Initial job status") message: str = Field(..., description = "Human-readable status 
message") error: Optional[str] = Field(None, description = "Error details if status is 'error'") class TrainingStatus(BaseModel): """Current training job status - works for streaming or polling""" job_id: str = Field(..., description = "Training job identifier") phase: Literal[ "idle", "loading_model", "loading_dataset", "configuring", "training", "completed", "error", "stopped", ] = Field(..., description = "Current phase of training pipeline") is_training_running: bool = Field( ..., description = "True if training loop is actively running" ) eval_enabled: bool = Field( False, description = "True if evaluation dataset is configured for this training run", ) message: str = Field(..., description = "Human-readable status message") error: Optional[str] = Field(None, description = "Error details if phase is 'error'") details: Optional[dict] = Field( None, description = "Phase-specific info, e.g. {'model_size': '8B'}" ) metric_history: Optional[dict] = Field( None, description = "Full metric history arrays for chart recovery after SSE reconnection. 
" "Keys: 'steps', 'loss', 'lr', 'grad_norm', 'grad_norm_steps' — each a list of numeric values.", ) class TrainingProgress(BaseModel): """Training progress metrics - for streaming or polling""" job_id: str = Field(..., description = "Training job identifier") step: int = Field(..., description = "Current training step") total_steps: int = Field(..., description = "Total training steps") loss: float = Field(..., description = "Current loss value") learning_rate: float = Field(..., description = "Current learning rate") progress_percent: float = Field( ..., description = "Progress percentage (0.0 to 100.0)" ) epoch: Optional[float] = Field(None, description = "Current epoch") elapsed_seconds: Optional[float] = Field( None, description = "Time elapsed since training started" ) eta_seconds: Optional[float] = Field(None, description = "Estimated time remaining") grad_norm: Optional[float] = Field( None, description = "L2 norm of gradients, computed before gradient clipping" ) num_tokens: Optional[int] = Field( None, description = "Total number of tokens processed so far" ) eval_loss: Optional[float] = Field( None, description = "Eval loss from the most recent evaluation step" ) ================================================ FILE: studio/backend/models/users.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Pydantic models for authentication tokens. This module defines the Token response model used by auth routes. 
"""

from pydantic import BaseModel, Field


class Token(BaseModel):
    """Authentication response model for session credentials.

    Used as the response model of the login, refresh and change-password
    routes. All four fields are required.
    """

    access_token: str = Field(
        ..., description = "Session access credential used for authenticated API requests"
    )
    refresh_token: str = Field(
        ...,
        description = "Session refresh credential used to renew an expired access credential",
    )
    # The auth routes always set this to "bearer".
    token_type: str = Field(
        ..., description = "Credential type for the Authorization header, always 'bearer'"
    )
    must_change_password: bool = Field(
        ..., description = "True when the user must change the seeded default password"
    )

================================================
FILE: studio/backend/plugins/__init__.py
================================================

================================================
FILE: studio/backend/plugins/data-designer-unstructured-seed/__init__.py
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0

================================================
FILE: studio/backend/plugins/data-designer-unstructured-seed/pyproject.toml
================================================
# SPDX-License-Identifier: AGPL-3.0-only
# Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
See /studio/LICENSE.AGPL-3.0 [build-system] requires = ["setuptools>=68", "wheel"] build-backend = "setuptools.build_meta" [project] name = "data-designer-unstructured-seed" version = "0.1.0" description = "Local Data Designer unstructured seed reader plugin" requires-python = ">=3.11" dependencies = [ "data-designer-engine>=0.5.1,<0.6", "pandas>=2,<3", ] [project.entry-points."data_designer.plugins"] unstructured = "data_designer_unstructured_seed.plugin:unstructured_seed_plugin" [tool.setuptools] package-dir = {"" = "src"} [tool.setuptools.packages.find] where = ["src"] ================================================ FILE: studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from .chunking import ( DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE, build_unstructured_preview_rows, materialize_unstructured_seed_dataset, resolve_chunking, ) from .config import UnstructuredSeedSource from .impl import UnstructuredSeedReader from .plugin import unstructured_seed_plugin __all__ = [ "DEFAULT_CHUNK_OVERLAP", "DEFAULT_CHUNK_SIZE", "build_unstructured_preview_rows", "materialize_unstructured_seed_dataset", "resolve_chunking", "UnstructuredSeedSource", "UnstructuredSeedReader", "unstructured_seed_plugin", ] ================================================ FILE: studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/chunking.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import hashlib import re from pathlib import Path from typing import Any from utils.paths import ensure_dir, unstructured_seed_cache_root DEFAULT_CHUNK_SIZE = 1200 DEFAULT_CHUNK_OVERLAP = 200 MAX_CHUNK_SIZE = 20000 _MIN_BREAK_RATIO = 0.6 _CACHE_DIR = unstructured_seed_cache_root() def resolve_chunking( chunk_size: Any, chunk_overlap: Any, ) -> tuple[int, int]: size = _to_int(chunk_size, DEFAULT_CHUNK_SIZE) size = max(1, min(size, MAX_CHUNK_SIZE)) overlap = _to_int(chunk_overlap, DEFAULT_CHUNK_OVERLAP) overlap = max(0, min(overlap, max(0, size - 1))) return size, overlap def build_unstructured_preview_rows( *, source_path: Path, preview_size: int, chunk_size: Any, chunk_overlap: Any, ) -> list[dict[str, str]]: parquet_path, rows = materialize_unstructured_seed_dataset( source_path = source_path, chunk_size = chunk_size, chunk_overlap = chunk_overlap, ) count = max(0, int(preview_size)) if rows: return rows[:count] try: import pandas as pd except ImportError as exc: # pragma: no cover raise RuntimeError( f"pandas is required for unstructured seed processing: {exc}" ) from exc dataframe = pd.read_parquet(parquet_path).head(count) return [ {"chunk_text": str(value.get("chunk_text", "")).strip()} for value in dataframe.to_dict(orient = "records") if str(value.get("chunk_text", "")).strip() ] def materialize_unstructured_seed_dataset( *, source_path: Path, chunk_size: Any, chunk_overlap: Any, ) -> tuple[Path, list[dict[str, str]]]: resolved = source_path.expanduser().resolve() if not resolved.is_file(): raise FileNotFoundError(f"unstructured seed file not found: {resolved}") size, overlap = resolve_chunking(chunk_size, chunk_overlap) key = _compute_cache_key( source_path = resolved, chunk_size = size, chunk_overlap = overlap, ) parquet_path = _CACHE_DIR / f"{key}.parquet" if parquet_path.exists(): return parquet_path, [] text = load_unstructured_text_file(resolved) chunks = split_text_into_chunks( text = 
text, chunk_size = size, chunk_overlap = overlap, ) if not chunks: raise ValueError("No text found in unstructured seed source.") rows = [{"chunk_text": chunk} for chunk in chunks] ensure_dir(_CACHE_DIR) try: import pandas as pd except ImportError as exc: # pragma: no cover raise RuntimeError( f"pandas is required for unstructured seed processing: {exc}" ) from exc tmp_path = _CACHE_DIR / f"{key}.tmp.parquet" pd.DataFrame(rows).to_parquet(tmp_path, index = False) tmp_path.replace(parquet_path) return parquet_path, rows def load_unstructured_text_file(path: Path) -> str: ext = path.suffix.lower() if ext not in {".txt", ".md"}: raise ValueError(f"Unsupported unstructured seed file type: {ext}") raw = path.read_text(encoding = "utf-8", errors = "ignore") return normalize_unstructured_text(raw) def normalize_unstructured_text(text: str) -> str: normalized = text.replace("\r\n", "\n").replace("\r", "\n") return re.sub(r"\n{3,}", "\n\n", normalized).strip() def split_text_into_chunks( *, text: str, chunk_size: int, chunk_overlap: int, ) -> list[str]: if not text: return [] if chunk_size <= 0: return [text] chunks: list[str] = [] start = 0 min_break_index = int(chunk_size * _MIN_BREAK_RATIO) text_len = len(text) while start < text_len: end = min(text_len, start + chunk_size) if end < text_len: window = text[start:end] cut = _find_break_index(window, min_break_index) if cut is not None and cut > 0: end = start + cut if end <= start: end = min(text_len, start + chunk_size) chunk = text[start:end].strip() if chunk: chunks.append(chunk) if end >= text_len: break next_start = end - chunk_overlap if next_start <= start: next_start = end start = max(0, next_start) return chunks def _find_break_index(window: str, min_index: int) -> int | None: breakpoints = ["\n\n", "\n", " "] for token in breakpoints: idx = window.rfind(token) if idx >= min_index: return idx + len(token) return None def _to_int(value: Any, fallback: int) -> int: if isinstance(value, bool): return fallback try: 
parsed = int(str(value).strip()) except (TypeError, ValueError): return fallback return parsed def _compute_cache_key( *, source_path: Path, chunk_size: int, chunk_overlap: int, ) -> str: stat = source_path.stat() payload = "|".join( [ str(source_path), str(stat.st_size), str(stat.st_mtime_ns), str(chunk_size), str(chunk_overlap), ] ).encode("utf-8") return hashlib.sha256(payload).hexdigest() ================================================ FILE: studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/config.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from __future__ import annotations from pathlib import Path from typing import Literal from pydantic import Field, field_validator from data_designer.config.seed_source import SeedSource from .chunking import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE, resolve_chunking class UnstructuredSeedSource(SeedSource): seed_type: Literal["unstructured"] = "unstructured" path: str = Field(..., min_length = 1) chunk_size: int = DEFAULT_CHUNK_SIZE chunk_overlap: int = DEFAULT_CHUNK_OVERLAP @field_validator("path", mode = "after") @classmethod def _validate_path(cls, value: str) -> str: path = Path(value).expanduser() if not path.is_file(): raise ValueError(f"Unstructured seed path is not a file: {path}") return value @field_validator("chunk_size", mode = "after") @classmethod def _validate_chunk_size(cls, value: int) -> int: size, _ = resolve_chunking(value, 0) return size @field_validator("chunk_overlap", mode = "after") @classmethod def _validate_chunk_overlap(cls, value: int, info) -> int: size = info.data.get("chunk_size", cls.model_fields["chunk_size"].default) _, overlap = resolve_chunking(size, value) return overlap ================================================ FILE: 
studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/impl.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from __future__ import annotations from pathlib import Path import data_designer.lazy_heavy_imports as lazy from data_designer.engine.resources.seed_reader import SeedReader from .chunking import materialize_unstructured_seed_dataset from .config import UnstructuredSeedSource class UnstructuredSeedReader(SeedReader[UnstructuredSeedSource]): def create_duckdb_connection(self): return lazy.duckdb.connect() def get_dataset_uri(self) -> str: path, _ = materialize_unstructured_seed_dataset( source_path = Path(self.source.path), chunk_size = self.source.chunk_size, chunk_overlap = self.source.chunk_overlap, ) return str(path) ================================================ FILE: studio/backend/plugins/data-designer-unstructured-seed/src/data_designer_unstructured_seed/plugin.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from data_designer.plugins.plugin import Plugin, PluginType unstructured_seed_plugin = Plugin( impl_qualified_name = "data_designer_unstructured_seed.impl.UnstructuredSeedReader", config_qualified_name = "data_designer_unstructured_seed.config.UnstructuredSeedSource", plugin_type = PluginType.SEED_READER, ) ================================================ FILE: studio/backend/requirements/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0

================================================
FILE: studio/backend/requirements/base.txt
================================================
# Core unsloth packages
unsloth-zoo
unsloth

================================================
FILE: studio/backend/requirements/extras-no-deps.txt
================================================
# Audio extras (installed with --no-deps --no-cache-dir)
descript-audio-codec
descript-audiotools
julius
torchcodec
snac

# TRL and related packages
trl==0.23.1
git+https://github.com/meta-pytorch/OpenEnv.git
# executorch>=1.0.1 # 41.5 MB - no imports in unsloth/zoo/studio
torch-c-dlpack-ext
sentence_transformers==5.2.0
transformers==4.57.6

================================================
FILE: studio/backend/requirements/extras.txt
================================================
# OpenEnv dependencies
tomli
tomli-w

# ExecuTorch dependencies
ruamel.yaml
# coremltools # 10.2 MB - Apple CoreML, no imports in unsloth/zoo/studio
expecttest
flatbuffers
hydra-core
hypothesis
kgb
parameterized
pytest<9.0
pytest-json-report
pytest-rerunfailures==15.1
pytest-xdist

# Also needed by sentence_transformers (installed with --no-deps in extras-no-deps.txt)
scikit-learn==1.7.1

# Additional extras
pybind11
langid
jiwer
omegaconf
einx
pyloudnorm
openai-whisper
uroman # 4.0 MB - used for OuteTTS.
MeCab # 19.9 MB - used for OuteTTS.
inflect # number-to-words, required by OuteTTS loguru flatten_dict ffmpy randomname argbind tiktoken ftfy importlib-resources librosa markdown2 matplotlib pystoi soundfile tensorboard torch-stoi evaluate timm transformers-cfg open_spiel addict easydict einops tabulate fastmcp>=3.0.2 openai>=2.7.2 websockets>=15.0.1 ================================================ FILE: studio/backend/requirements/overrides.txt ================================================ # Torch AO overrides (installed with --force-reinstall --no-cache-dir) torchao==0.14.0 pytorch_tokenizers # Kernel packages kernels ================================================ FILE: studio/backend/requirements/single-env/constraints.txt ================================================ # Single-env pins for unsloth + studio + data-designer # Keep compatible with unsloth transformers bounds. transformers==4.57.6 trl==0.23.1 huggingface-hub==0.36.2 # Studio stack datasets==4.3.0 pyarrow==23.0.1 # FastMCP/OpenEnv compat fastmcp>=3.0.2 mcp>=1.24,<2 websockets>=15.0.1 pandas==2.3.3 ================================================ FILE: studio/backend/requirements/single-env/data-designer-deps.txt ================================================ # Data Designer runtime deps installed explicitly (single-env mode). # DuckDB 1.5 removed Relation.record_batch(); keep <1.5 until upstream ships the fix. anyascii<1,>=0.3.3 duckdb<1.5,>=1.1.3 faker<21,>=20.1.0 httpx<1,>=0.27.2 httpx-retries<1,>=0.4.2 json-repair<1,>=0.48.0 jsonpath-rust-bindings<2,>=1.0 jsonschema<5,>=4.0.0 litellm<1.80.12,>=1.73.6 lxml<7,>=6.0.2 marko<3,>=2.1.2 networkx<4,>=3.0 python-json-logger<4,>=3 ruff<1,>=0.14.10 scipy<2,>=1.11.0 sqlfluff<4,>=3.2.0 tiktoken<1,>=0.8.0 ================================================ FILE: studio/backend/requirements/single-env/data-designer.txt ================================================ # Install Data Designer in same env as Unsloth. 
data-designer==0.5.2 data-designer-config==0.5.2 data-designer-engine==0.5.2 prompt-toolkit>=3,<4 ================================================ FILE: studio/backend/requirements/single-env/patch_metadata.py ================================================ #!/usr/bin/env python3 # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Relax strict metadata pins so pip check matches known working single-env stack. Why: - data-designer pins huggingface-hub>=1.0.1 and pyarrow<20. - unsloth/transformers pins huggingface-hub<1. - studio datasets pins pyarrow>=21. Runtime works in this app with hub 0.36.x + pyarrow 23.x, but metadata conflicts. """ from __future__ import annotations import importlib.metadata as im import re from pathlib import Path TARGETS = ( "data-designer", "data-designer-engine", "data-designer-config", ) PATCHES: tuple[tuple[re.Pattern[str], str], ...] = ( ( re.compile(r"^Requires-Dist: huggingface-hub<2,>=1\.0\.1$", re.MULTILINE), "Requires-Dist: huggingface-hub<2,>=0.34.0", ), ( re.compile(r"^Requires-Dist: pyarrow<20,>=19\.0\.1$", re.MULTILINE), "Requires-Dist: pyarrow>=21.0.0", ), ) def metadata_path(dist_name: str) -> Path | None: try: dist = im.distribution(dist_name) except im.PackageNotFoundError: return None for f in dist.files or []: sf = str(f) if sf.endswith(".dist-info/METADATA"): return Path(dist.locate_file(f)) return None def patch_file(path: Path) -> bool: original = path.read_text(encoding = "utf-8") updated = original for pattern, repl in PATCHES: updated = pattern.sub(repl, updated) if updated == original: return False path.write_text(updated, encoding = "utf-8") return True def main() -> int: changed = 0 checked = 0 for name in TARGETS: p = metadata_path(name) if p is None: continue checked += 1 if patch_file(p): changed += 1 print(f"single-env metadata patch: checked={checked}, changed={changed}") return 0 if __name__ == "__main__": raise 
SystemExit(main()) ================================================ FILE: studio/backend/requirements/studio.txt ================================================ # Studio UI backend dependencies typer fastapi uvicorn pydantic matplotlib pandas nest_asyncio datasets==4.3.0 pyjwt easydict addict # gradio>=4.0.0 # 148 MB - Studio uses React + FastAPI, not Gradio huggingface-hub==0.36.2 structlog>=24.1.0 diceware ddgs ================================================ FILE: studio/backend/requirements/triton-kernels.txt ================================================ # Triton kernels (installed with --no-deps, from source) triton_kernels @ git+https://github.com/triton-lang/triton.git@release/3.6.x#subdirectory=python/triton_kernels ================================================ FILE: studio/backend/routes/.gitkeep ================================================ ================================================ FILE: studio/backend/routes/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ API Routes """ from routes.training import router as training_router from routes.models import router as models_router from routes.inference import router as inference_router from routes.datasets import router as datasets_router from routes.auth import router as auth_router from routes.data_recipe import router as data_recipe_router from routes.export import router as export_router __all__ = [ "training_router", "models_router", "inference_router", "datasets_router", "auth_router", "data_recipe_router", "export_router", ] ================================================ FILE: studio/backend/routes/auth.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Authentication API routes """ from fastapi import APIRouter, Depends, HTTPException, status from models.auth import ( AuthLoginRequest, RefreshTokenRequest, AuthStatusResponse, ChangePasswordRequest, ) from models.users import Token from auth import storage, hashing from auth.authentication import ( create_access_token, create_refresh_token, get_current_subject, get_current_subject_allow_password_change, refresh_access_token, ) router = APIRouter() @router.get("/status", response_model = AuthStatusResponse) async def auth_status() -> AuthStatusResponse: """ Check whether auth has already been initialized. - initialized = False -> frontend should wait for the seeded admin bootstrap. - initialized = True -> frontend should show login or force the first password change. """ return AuthStatusResponse( initialized = storage.is_initialized(), default_username = storage.DEFAULT_ADMIN_USERNAME, requires_password_change = storage.requires_password_change( storage.DEFAULT_ADMIN_USERNAME ) if storage.is_initialized() else True, ) @router.post("/login", response_model = Token) async def login(payload: AuthLoginRequest) -> Token: """ Login with username/password and receive access + refresh tokens. """ record = storage.get_user_and_secret(payload.username) if record is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Incorrect password. Run 'unsloth studio reset-password' in your terminal to reset it.", ) salt, pwd_hash, _jwt_secret, must_change_password = record if not hashing.verify_password(payload.password, salt, pwd_hash): raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Incorrect password. 
Run 'unsloth studio reset-password' in your terminal to reset it.", ) access_token = create_access_token(subject = payload.username) refresh_token = create_refresh_token(subject = payload.username) return Token( access_token = access_token, refresh_token = refresh_token, token_type = "bearer", must_change_password = must_change_password, ) @router.post("/refresh", response_model = Token) async def refresh(payload: RefreshTokenRequest) -> Token: """ Exchange a valid refresh token for a new access token. The refresh token itself is reusable until it expires (7 days). """ new_access_token, username = refresh_access_token(payload.refresh_token) if new_access_token is None or username is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid or expired refresh token", ) return Token( access_token = new_access_token, refresh_token = payload.refresh_token, token_type = "bearer", must_change_password = storage.requires_password_change(username), ) @router.post("/change-password", response_model = Token) async def change_password( payload: ChangePasswordRequest, current_subject: str = Depends(get_current_subject_allow_password_change), ) -> Token: """Allow the authenticated user to replace the default password.""" record = storage.get_user_and_secret(current_subject) if record is None: raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "User session is invalid", ) salt, pwd_hash, _jwt_secret, _must_change_password = record if not hashing.verify_password(payload.current_password, salt, pwd_hash): raise HTTPException( status_code = status.HTTP_401_UNAUTHORIZED, detail = "Current password is incorrect", ) if payload.current_password == payload.new_password: raise HTTPException( status_code = status.HTTP_400_BAD_REQUEST, detail = "New password must be different from the current password", ) storage.update_password(current_subject, payload.new_password) storage.revoke_user_refresh_tokens(current_subject) access_token = 
create_access_token(subject = current_subject) refresh_token = create_refresh_token(subject = current_subject) return Token( access_token = access_token, refresh_token = refresh_token, token_type = "bearer", must_change_password = False, ) ================================================ FILE: studio/backend/routes/data_recipe/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Data Recipe route package.""" from __future__ import annotations import sys from pathlib import Path from fastapi import APIRouter, Depends from auth.authentication import get_current_subject backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) from .jobs import router as jobs_router from .mcp import router as mcp_router from .seed import router as seed_router from .validate import router as validate_router router = APIRouter(dependencies = [Depends(get_current_subject)]) router.include_router(seed_router) router.include_router(validate_router) router.include_router(jobs_router) router.include_router(mcp_router) __all__ = ["router"] ================================================ FILE: studio/backend/routes/data_recipe/jobs.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """Job lifecycle endpoints for data recipe.""" from __future__ import annotations from typing import Any from fastapi import APIRouter, HTTPException, Query, Request from fastapi.responses import JSONResponse, StreamingResponse from pydantic import ValidationError from core.data_recipe.huggingface import ( RecipeDatasetPublishError, publish_recipe_dataset, ) from core.data_recipe.jobs import get_job_manager from models.data_recipe import ( JobCreateResponse, PublishDatasetRequest, PublishDatasetResponse, RecipePayload, ) router = APIRouter() def _normalize_run_name(value: Any) -> str | None: if value is None: return None if not isinstance(value, str): raise HTTPException( status_code = 400, detail = "invalid run_name: must be a string" ) trimmed = value.strip() if not trimmed: return None return trimmed[:120] @router.post("/jobs", response_class = JSONResponse, response_model = JobCreateResponse) def create_job(payload: RecipePayload): recipe = payload.recipe if not recipe.get("columns"): raise HTTPException(status_code = 400, detail = "Recipe must include columns.") run: dict[str, Any] = payload.run or {} run.pop("artifact_path", None) run.pop("dataset_name", None) execution_type = str(run.get("execution_type") or "full").strip().lower() if execution_type not in {"preview", "full"}: raise HTTPException( status_code = 400, detail = "invalid execution_type: must be 'preview' or 'full'", ) run["execution_type"] = execution_type run["run_name"] = _normalize_run_name(run.get("run_name")) run_config_raw = run.get("run_config") if run_config_raw is not None: try: from data_designer.config.run_config import RunConfig RunConfig.model_validate(run_config_raw) except (ImportError, ValidationError, TypeError, ValueError) as exc: raise HTTPException( status_code = 400, detail = f"invalid run_config: {exc}" ) from exc mgr = get_job_manager() try: job_id = mgr.start(recipe = recipe, run = run) except RuntimeError as exc: raise 
HTTPException(status_code = 409, detail = str(exc)) from exc except ValueError as exc: raise HTTPException(status_code = 400, detail = str(exc)) from exc return {"job_id": job_id} @router.get("/jobs/{job_id}/status") def job_status(job_id: str): mgr = get_job_manager() state = mgr.get_status(job_id) if state is None: raise HTTPException(status_code = 404, detail = "job not found") return state @router.get("/jobs/current") def current_job(): mgr = get_job_manager() state = mgr.get_current_status() if state is None: raise HTTPException(status_code = 404, detail = "no job") return state @router.post("/jobs/{job_id}/cancel") def cancel_job(job_id: str): mgr = get_job_manager() ok = mgr.cancel(job_id) if not ok: raise HTTPException(status_code = 404, detail = "job not found") return mgr.get_status(job_id) @router.get("/jobs/{job_id}/analysis") def job_analysis(job_id: str): mgr = get_job_manager() analysis = mgr.get_analysis(job_id) if analysis is None: raise HTTPException(status_code = 404, detail = "analysis not ready") return analysis @router.get("/jobs/{job_id}/dataset") def job_dataset( job_id: str, limit: int = Query(default = 20, ge = 1, le = 500), offset: int = Query(default = 0, ge = 0), ): mgr = get_job_manager() result = mgr.get_dataset(job_id, limit = limit, offset = offset) if result is None: raise HTTPException(status_code = 404, detail = "dataset not ready") if "error" in result: raise HTTPException(status_code = 422, detail = result["error"]) return { "dataset": result["dataset"], "total": result["total"], "limit": limit, "offset": offset, } @router.post( "/jobs/{job_id}/publish", response_class = JSONResponse, response_model = PublishDatasetResponse, ) def publish_job_dataset(job_id: str, payload: PublishDatasetRequest): repo_id = payload.repo_id.strip() description = payload.description.strip() hf_token = payload.hf_token.strip() if isinstance(payload.hf_token, str) else None artifact_path = ( payload.artifact_path.strip() if 
isinstance(payload.artifact_path, str) else None ) if not repo_id: raise HTTPException(status_code = 400, detail = "repo_id is required") if not description: raise HTTPException(status_code = 400, detail = "description is required") mgr = get_job_manager() status = mgr.get_status(job_id) if status is not None: if ( status.get("status") != "completed" or status.get("execution_type") != "full" ): raise HTTPException( status_code = 409, detail = "Only completed full runs can be published.", ) status_artifact = status.get("artifact_path") if isinstance(status_artifact, str) and status_artifact.strip(): artifact_path = status_artifact.strip() if not artifact_path: raise HTTPException( status_code = 400, detail = "This execution does not have publishable dataset artifacts.", ) try: url = publish_recipe_dataset( artifact_path = artifact_path, repo_id = repo_id, description = description, hf_token = hf_token or None, private = payload.private, ) except RecipeDatasetPublishError as exc: raise HTTPException(status_code = 400, detail = str(exc)) from exc except Exception as exc: raise HTTPException(status_code = 500, detail = str(exc)) from exc return { "success": True, "url": url, "message": f"Published dataset to {repo_id}.", } @router.get("/jobs/{job_id}/events") async def job_events(request: Request, job_id: str): mgr = get_job_manager() last_id = request.headers.get("last-event-id") after_seq: int | None = None if last_id: try: after_seq = int(str(last_id).strip()) except (TypeError, ValueError): after_seq = None after_q = request.query_params.get("after") if after_q: try: after_seq = int(str(after_q).strip()) except (TypeError, ValueError): pass sub = mgr.subscribe(job_id, after_seq = after_seq) if sub is None: raise HTTPException(status_code = 404, detail = "job not found") async def gen(): try: for event in sub.replay: yield sub.format_sse(event) while True: if await request.is_disconnected(): break event = await sub.next_event(timeout_sec = 1.0) if event is None: 
continue yield sub.format_sse(event) finally: mgr.unsubscribe(sub) return StreamingResponse(gen(), media_type = "text/event-stream") ================================================ FILE: studio/backend/routes/data_recipe/mcp.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """MCP helper endpoints for data recipe.""" from __future__ import annotations from collections import defaultdict from fastapi import APIRouter from core.data_recipe.service import build_mcp_providers from models.data_recipe import ( McpToolsListRequest, McpToolsListResponse, McpToolsProviderResult, ) router = APIRouter() @router.post("/mcp/tools", response_model = McpToolsListResponse) def list_mcp_tools(payload: McpToolsListRequest) -> McpToolsListResponse: try: from data_designer.engine.mcp import io as mcp_io except ImportError as exc: return McpToolsListResponse( providers = [ McpToolsProviderResult( name = "", error = f"MCP dependencies unavailable: {exc}", ) ] ) providers: list[McpToolsProviderResult] = [] tool_to_providers: dict[str, list[str]] = defaultdict(list) for provider_payload in payload.mcp_providers: provider_name = str(provider_payload.get("name", "")).strip() built = build_mcp_providers({"mcp_providers": [provider_payload]}) if len(built) != 1: providers.append( McpToolsProviderResult( name = provider_name, error = "Unsupported MCP provider config.", ) ) continue provider = built[0] try: tools = mcp_io.list_tools(provider, timeout_sec = payload.timeout_sec) tool_names = sorted( {tool.name for tool in tools if getattr(tool, "name", "")} ) for tool_name in tool_names: tool_to_providers[tool_name].append(provider.name) providers.append( McpToolsProviderResult( name = provider.name, tools = tool_names, ) ) except Exception as exc: providers.append( McpToolsProviderResult( name = provider.name or provider_name, error = str(exc).strip() 
or "Failed to load tools.", ) ) duplicate_tools = { tool_name: provider_names for tool_name, provider_names in sorted(tool_to_providers.items()) if len(provider_names) > 1 } return McpToolsListResponse( providers = providers, duplicate_tools = duplicate_tools, ) ================================================ FILE: studio/backend/routes/data_recipe/seed.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Seed inspect endpoints for data recipe.""" from __future__ import annotations import base64 import binascii from itertools import islice from pathlib import Path from typing import Any from uuid import uuid4 from fastapi import APIRouter, HTTPException from data_designer_unstructured_seed.chunking import ( build_unstructured_preview_rows, resolve_chunking, ) from core.data_recipe.jsonable import to_preview_jsonable from utils.paths import ensure_dir, seed_uploads_root from models.data_recipe import ( SeedInspectRequest, SeedInspectResponse, SeedInspectUploadRequest, ) router = APIRouter() DATA_EXTS = (".parquet", ".jsonl", ".json", ".csv") DEFAULT_SPLIT = "train" LOCAL_UPLOAD_EXTS = {".csv", ".json", ".jsonl"} UNSTRUCTURED_UPLOAD_EXTS = {".txt", ".md"} SEED_UPLOAD_DIR = seed_uploads_root() def _serialize_preview_value(value: Any) -> Any: return to_preview_jsonable(value) def _serialize_preview_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: return [ {str(key): _serialize_preview_value(value) for key, value in row.items()} for row in rows ] def _normalize_optional_text(value: str | None) -> str | None: if value is None: return None trimmed = value.strip() return trimmed if trimmed else None def _list_hf_data_files(*, dataset_name: str, token: str | None) -> list[str]: try: from huggingface_hub import HfApi from huggingface_hub.utils import HfHubHTTPError except ImportError: return [] try: api = HfApi() repo_files 
= api.list_repo_files(dataset_name, repo_type = "dataset", token = token) return [file for file in repo_files if file.lower().endswith(DATA_EXTS)] except (HfHubHTTPError, OSError, ValueError): return [] def _select_best_file(data_files: list[str], split: str = DEFAULT_SPLIT) -> str | None: if not data_files: return None split_lower = split.lower() def score(path: str) -> tuple[int, int]: name = path.lower() if f"/{split_lower}/" in name: return (0, len(path)) if ( f"_{split_lower}." in name or f"-{split_lower}." in name or f"/{split_lower}." in name or f"/{split_lower}_" in name or f"/{split_lower}-" in name ): return (1, len(path)) return (2, len(path)) return sorted(data_files, key = score)[0] def _resolve_seed_hf_path( dataset_name: str, data_files: list[str], split: str = DEFAULT_SPLIT ) -> str | None: selected = _select_best_file(data_files, split) if not selected: return None ext = Path(selected).suffix.lower() if ext not in DATA_EXTS: return f"datasets/{dataset_name}/{selected}" parent = Path(selected).parent.as_posix() if not parent or parent == ".": return f"datasets/{dataset_name}/**/*{ext}" return f"datasets/{dataset_name}/{parent}/**/*{ext}" def _build_stream_load_kwargs( *, dataset_name: str, split: str, subset: str | None, token: str | None, data_file: str | None = None, ) -> dict[str, Any]: kwargs: dict[str, Any] = { "path": dataset_name, "split": split, "streaming": True, "trust_remote_code": False, } if data_file: kwargs["data_files"] = [data_file] if subset: kwargs["name"] = subset if token: kwargs["token"] = token return kwargs def _load_preview_rows( *, load_dataset_fn, load_kwargs: dict[str, Any], preview_size: int, ) -> list[dict[str, Any]]: streamed_ds = load_dataset_fn(**load_kwargs) return [row for row in islice(streamed_ds, preview_size)] def _extract_columns(rows: list[dict[str, Any]]) -> list[str]: columns_seen: dict[str, None] = {} for row in rows: for key in row.keys(): columns_seen[str(key)] = None return list(columns_seen.keys()) def 
_sanitize_filename(filename: str) -> str: name = Path(filename).name.strip().replace("\x00", "") if not name: return "seed_upload" return name def _decode_base64_payload(content_base64: str) -> bytes: raw = content_base64.strip() if "," in raw and raw.lower().startswith("data:"): raw = raw.split(",", 1)[1] try: return base64.b64decode(raw, validate = True) except binascii.Error as exc: raise HTTPException(status_code = 400, detail = "invalid base64 payload") from exc def _read_preview_rows_from_local_file( path: Path, preview_size: int ) -> list[dict[str, Any]]: try: import pandas as pd except ImportError as exc: raise HTTPException( status_code = 500, detail = f"seed inspect dependencies unavailable: {exc}" ) from exc ext = path.suffix.lower() try: if ext == ".csv": df = pd.read_csv(path, nrows = preview_size) elif ext == ".jsonl": df = pd.read_json(path, lines = True).head(preview_size) elif ext == ".json": try: df = pd.read_json(path).head(preview_size) except ValueError: df = pd.read_json(path, lines = True).head(preview_size) else: raise HTTPException(status_code = 422, detail = f"unsupported file type: {ext}") except HTTPException: raise except (ValueError, OSError) as exc: raise HTTPException( status_code = 422, detail = f"seed inspect failed: {exc}" ) from exc rows = df.to_dict(orient = "records") return _serialize_preview_rows(rows) def _read_preview_rows_from_unstructured_file( *, path: Path, preview_size: int, chunk_size: int | None, chunk_overlap: int | None, ) -> list[dict[str, Any]]: size, overlap = resolve_chunking(chunk_size, chunk_overlap) try: rows = build_unstructured_preview_rows( source_path = path, preview_size = preview_size, chunk_size = size, chunk_overlap = overlap, ) except (FileNotFoundError, RuntimeError, ValueError, OSError) as exc: raise HTTPException( status_code = 422, detail = f"seed inspect failed: {exc}" ) from exc return _serialize_preview_rows(rows) @router.post("/seed/inspect", response_model = SeedInspectResponse) def 
inspect_seed_dataset(payload: SeedInspectRequest) -> SeedInspectResponse: dataset_name = payload.dataset_name.strip() if not dataset_name or dataset_name.count("/") < 1: raise HTTPException( status_code = 400, detail = "dataset_name must be a Hugging Face repo id like org/repo", ) try: from datasets import load_dataset except ImportError as exc: raise HTTPException( status_code = 500, detail = f"seed inspect dependencies unavailable: {exc}" ) from exc split = _normalize_optional_text(payload.split) or DEFAULT_SPLIT subset = _normalize_optional_text(payload.subset) token = _normalize_optional_text(payload.hf_token) preview_size = int(payload.preview_size) preview_rows: list[dict[str, Any]] = [] data_files = _list_hf_data_files(dataset_name = dataset_name, token = token) selected_file = _select_best_file(data_files, split) if selected_file: try: single_file_kwargs = _build_stream_load_kwargs( dataset_name = dataset_name, split = split, subset = subset, token = token, data_file = selected_file, ) preview_rows = _load_preview_rows( load_dataset_fn = load_dataset, load_kwargs = single_file_kwargs, preview_size = preview_size, ) except (ValueError, OSError, RuntimeError): preview_rows = [] if not preview_rows: try: split_kwargs = _build_stream_load_kwargs( dataset_name = dataset_name, split = split, subset = subset, token = token, ) preview_rows = _load_preview_rows( load_dataset_fn = load_dataset, load_kwargs = split_kwargs, preview_size = preview_size, ) except (ValueError, OSError, RuntimeError) as exc: raise HTTPException( status_code = 422, detail = f"seed inspect failed: {exc}" ) from exc if not preview_rows: raise HTTPException( status_code = 422, detail = "dataset appears empty or unreadable" ) preview_rows = _serialize_preview_rows(preview_rows) columns = _extract_columns(preview_rows) if not data_files: resolved_path = f"datasets/{dataset_name}/**/*.parquet" else: resolved_path = _resolve_seed_hf_path(dataset_name, data_files, split) if not resolved_path: raise 
HTTPException( status_code = 422, detail = "unable to resolve seed dataset path" ) return SeedInspectResponse( dataset_name = dataset_name, resolved_path = resolved_path, columns = columns, preview_rows = preview_rows, split = split, subset = subset, ) @router.post("/seed/inspect-upload", response_model = SeedInspectResponse) def inspect_seed_upload(payload: SeedInspectUploadRequest) -> SeedInspectResponse: seed_source_type = _normalize_optional_text(payload.seed_source_type) or "local" filename = _sanitize_filename(payload.filename) ext = Path(filename).suffix.lower() if seed_source_type == "unstructured": if ext not in UNSTRUCTURED_UPLOAD_EXTS: allowed = ", ".join(sorted(UNSTRUCTURED_UPLOAD_EXTS)) raise HTTPException( status_code = 400, detail = f"unsupported file type: {ext}. allowed: {allowed}", ) else: if ext not in LOCAL_UPLOAD_EXTS: allowed = ", ".join(sorted(LOCAL_UPLOAD_EXTS)) raise HTTPException( status_code = 400, detail = f"unsupported file type: {ext}. allowed: {allowed}", ) file_bytes = _decode_base64_payload(payload.content_base64) if not file_bytes: raise HTTPException(status_code = 400, detail = "empty upload payload") max_size_bytes = 50 * 1024 * 1024 if len(file_bytes) > max_size_bytes: raise HTTPException(status_code = 413, detail = "file too large (max 50MB)") ensure_dir(SEED_UPLOAD_DIR) stored_name = f"{uuid4().hex}_{filename}" stored_path = SEED_UPLOAD_DIR / stored_name stored_path.write_bytes(file_bytes) if seed_source_type == "unstructured": preview_rows = _read_preview_rows_from_unstructured_file( path = stored_path, preview_size = int(payload.preview_size), chunk_size = payload.unstructured_chunk_size, chunk_overlap = payload.unstructured_chunk_overlap, ) else: preview_rows = _read_preview_rows_from_local_file( stored_path, int(payload.preview_size), ) if not preview_rows: raise HTTPException( status_code = 422, detail = "dataset appears empty or unreadable" ) columns = _extract_columns(preview_rows) return SeedInspectResponse( 
dataset_name = filename, resolved_path = str(stored_path), columns = columns, preview_rows = preview_rows, split = None, subset = None, ) ================================================ FILE: studio/backend/routes/data_recipe/validate.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """Validation endpoints for data recipe.""" from __future__ import annotations from typing import Any from fastapi import APIRouter, HTTPException from core.data_recipe.service import ( build_config_builder, create_data_designer, validate_recipe, ) from models.data_recipe import RecipePayload, ValidateError, ValidateResponse router = APIRouter() def _collect_validation_errors(recipe: dict[str, Any]) -> list[ValidateError]: try: from data_designer.engine.compiler import ( _add_internal_row_id_column_if_needed, _get_allowed_references, _resolve_and_add_seed_columns, ) from data_designer.engine.validation import ( ViolationLevel, validate_data_designer_config, ) except ImportError: return [] try: builder = build_config_builder(recipe) designer = create_data_designer(recipe) resource_provider = designer._create_resource_provider( # type: ignore[attr-defined] "validate-configuration", builder, ) config = builder.build() _resolve_and_add_seed_columns(config, resource_provider.seed_reader) _add_internal_row_id_column_if_needed(config) violations = validate_data_designer_config( columns = config.columns, processor_configs = config.processors or [], allowed_references = _get_allowed_references(config), ) except (TypeError, ValueError, AttributeError): return [] errors: list[ValidateError] = [] for violation in violations: if violation.level != ViolationLevel.ERROR: continue code = getattr(violation.type, "value", None) path = violation.column if violation.column else None message = str(violation.message).strip() or "Validation failed." 
errors.append( ValidateError( message = message, path = path, code = code, ) ) return errors @router.post("/validate", response_model = ValidateResponse) def validate(payload: RecipePayload) -> ValidateResponse: recipe = payload.recipe if not recipe.get("columns"): return ValidateResponse( valid = False, errors = [ValidateError(message = "Recipe must include columns.")], ) try: validate_recipe(recipe) except RuntimeError as exc: raise HTTPException(status_code = 503, detail = str(exc)) from exc except Exception as exc: detail = str(exc).strip() or "Validation failed." parsed_errors = _collect_validation_errors(recipe) return ValidateResponse( valid = False, errors = parsed_errors or [ValidateError(message = detail)], raw_detail = detail, ) return ValidateResponse(valid = True) ================================================ FILE: studio/backend/routes/datasets.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0

"""
Datasets API routes
"""

import base64
import io
import json
import sys
from pathlib import Path
from uuid import uuid4

from fastapi import APIRouter, Depends, HTTPException, UploadFile
import structlog
from loggers import get_logger

# Add backend directory to path
backend_path = Path(__file__).parent.parent.parent
if str(backend_path) not in sys.path:
    sys.path.insert(0, str(backend_path))

# Import dataset utilities
from utils.datasets import check_dataset_format

from auth.authentication import get_current_subject

router = APIRouter()

logger = get_logger(__name__)

from models.datasets import (
    AiAssistMappingRequest,
    AiAssistMappingResponse,
    CheckFormatRequest,
    CheckFormatResponse,
    LocalDatasetItem,
    LocalDatasetsResponse,
    UploadDatasetResponse,
)
from utils.paths import (
    dataset_uploads_root,
    ensure_dir,
    recipe_datasets_root,
    resolve_dataset_path,
)


def _encode_pil_preview(value):
    """Return a base64 JPEG payload when ``value`` is a PIL image, else ``None``.

    Any failure (PIL not installed, conversion/encoding error) is treated as
    "not an image", so the caller falls through to its generic handling.
    """
    try:
        from PIL.Image import Image as PILImage

        if not isinstance(value, PILImage):
            return None
        jpeg_bytes = io.BytesIO()
        value.convert("RGB").save(jpeg_bytes, format = "JPEG", quality = 85)
        return {
            "type": "image",
            "mime": "image/jpeg",
            "width": value.width,
            "height": value.height,
            "data": base64.b64encode(jpeg_bytes.getvalue()).decode("ascii"),
        }
    except Exception:
        return None


def _serialize_preview_value(value):
    """Convert one preview cell into a JSON-safe value for the client.

    ``None`` and str/int/float/bool pass through unchanged. PIL images become
    a base64 JPEG dict. Dicts (keys stringified) and lists/tuples (returned as
    lists) are converted recursively. Anything else is returned as ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value

    image_payload = _encode_pil_preview(value)
    if image_payload is not None:
        return image_payload

    if isinstance(value, dict):
        return {str(k): _serialize_preview_value(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_serialize_preview_value(element) for element in value]
    return str(value)


def _serialize_preview_rows(rows):
    """Serialize each row (anything ``dict()`` accepts) into a JSON-safe dict."""
    serialized = []
    for row in rows:
        record = dict(row)
        serialized.append(
            {str(k): _serialize_preview_value(v) for k, v in record.items()}
        )
    return serialized


# --- Endpoints ---

# Recognized data-file extensions for the single-file fallback approach.
# Tabular formats are preferred over archives for Tier 1 preview because
# archives (e.g.
images.zip) may be loaded as ImageFolder datasets with # synthetic columns (image/label) that don't match the real dataset schema. _TABULAR_EXTS = (".parquet", ".json", ".jsonl", ".csv", ".tsv", ".arrow") _ARCHIVE_EXTS = (".tar", ".tar.gz", ".tgz", ".gz", ".zst", ".zip", ".txt") DATA_EXTS = _TABULAR_EXTS + _ARCHIVE_EXTS LOCAL_FILE_EXTS = (".json", ".jsonl", ".csv", ".parquet") LOCAL_UPLOAD_EXTS = {".csv", ".json", ".jsonl", ".parquet"} LOCAL_DATASETS_ROOT = recipe_datasets_root() DATASET_UPLOAD_DIR = dataset_uploads_root() def _safe_read_metadata(path: Path) -> dict | None: try: payload = json.loads(path.read_text(encoding = "utf-8")) except (OSError, ValueError, TypeError): return None if not isinstance(payload, dict): return None return payload def _safe_read_rows_from_metadata(payload: dict | None) -> int | None: if not payload: return None for key in ("actual_num_records", "target_num_records"): value = payload.get(key) if isinstance(value, int): return value return None def _safe_read_metadata_summary(payload: dict | None) -> dict | None: if not payload: return None actual_num_records = ( payload.get("actual_num_records") if isinstance(payload.get("actual_num_records"), int) else None ) target_num_records = ( payload.get("target_num_records") if isinstance(payload.get("target_num_records"), int) else actual_num_records ) columns: list[str] | None = None schema = payload.get("schema") if isinstance(schema, dict): columns = [str(key) for key in schema.keys()] if not columns: stats = payload.get("column_statistics") if isinstance(stats, list): derived = [ str(item.get("column_name")) for item in stats if isinstance(item, dict) and item.get("column_name") ] columns = derived or None parquet_files_count = None file_paths = payload.get("file_paths") if isinstance(file_paths, dict): parquet_files = file_paths.get("parquet-files") if isinstance(parquet_files, list): parquet_files_count = len(parquet_files) total_num_batches = ( payload.get("total_num_batches") if 
isinstance(payload.get("total_num_batches"), int) else parquet_files_count ) num_completed_batches = ( payload.get("num_completed_batches") if isinstance(payload.get("num_completed_batches"), int) else total_num_batches ) return { "actual_num_records": actual_num_records, "target_num_records": target_num_records, "total_num_batches": total_num_batches, "num_completed_batches": num_completed_batches, "columns": columns, } def _build_local_dataset_items() -> list[LocalDatasetItem]: if not LOCAL_DATASETS_ROOT.exists(): return [] items: list[LocalDatasetItem] = [] for entry in LOCAL_DATASETS_ROOT.iterdir(): if not entry.is_dir() or not entry.name.startswith("recipe_"): continue parquet_dir = entry / "parquet-files" if not parquet_dir.exists() or not any(parquet_dir.glob("*.parquet")): continue rows = None metadata_summary = None metadata_path = entry / "metadata.json" if metadata_path.exists(): metadata_payload = _safe_read_metadata(metadata_path) rows = _safe_read_rows_from_metadata(metadata_payload) metadata_summary = _safe_read_metadata_summary(metadata_payload) try: updated_at = entry.stat().st_mtime except OSError: updated_at = None items.append( LocalDatasetItem( id = entry.name, label = entry.name, path = str(parquet_dir.resolve()), rows = rows, updated_at = updated_at, metadata = metadata_summary, ) ) items.sort(key = lambda item: item.updated_at or 0, reverse = True) return items def _load_local_preview_slice( *, dataset_path: Path, train_split: str, preview_size: int ): from datasets import load_dataset if dataset_path.is_dir(): parquet_dir = ( dataset_path / "parquet-files" if (dataset_path / "parquet-files").exists() else dataset_path ) parquet_files = sorted(parquet_dir.glob("*.parquet")) if parquet_files: dataset = load_dataset( "parquet", data_files = [str(path) for path in parquet_files], split = train_split, ) total_rows = len(dataset) preview_slice = dataset.select(range(min(preview_size, total_rows))) return preview_slice, total_rows else: 
candidate_files: list[Path] = [] for ext in LOCAL_FILE_EXTS: candidate_files.extend(sorted(dataset_path.glob(f"*{ext}"))) if not candidate_files: raise HTTPException( status_code = 400, detail = "Unsupported local dataset directory (expected parquet/json/jsonl/csv files)", ) dataset_path = candidate_files[0] if dataset_path.suffix in [".json", ".jsonl"]: dataset = load_dataset("json", data_files = str(dataset_path), split = train_split) elif dataset_path.suffix == ".csv": dataset = load_dataset("csv", data_files = str(dataset_path), split = train_split) elif dataset_path.suffix == ".parquet": dataset = load_dataset( "parquet", data_files = str(dataset_path), split = train_split ) else: raise HTTPException( status_code = 400, detail = f"Unsupported file format: {dataset_path.suffix}" ) total_rows = len(dataset) preview_slice = dataset.select(range(min(preview_size, total_rows))) return preview_slice, total_rows def _sanitize_filename(filename: str) -> str: name = Path(filename).name.strip().replace("\x00", "") if not name: return "dataset_upload" return name @router.post("/upload", response_model = UploadDatasetResponse) async def upload_dataset( file: UploadFile, current_subject: str = Depends(get_current_subject), ) -> UploadDatasetResponse: filename = _sanitize_filename(file.filename or "dataset_upload") ext = Path(filename).suffix.lower() if ext not in LOCAL_UPLOAD_EXTS: allowed = ", ".join(sorted(LOCAL_UPLOAD_EXTS)) raise HTTPException( status_code = 400, detail = f"Unsupported file type: {ext}. 
Allowed: {allowed}", ) ensure_dir(DATASET_UPLOAD_DIR) stem = Path(filename).stem stored_name = f"{uuid4().hex}_{stem}{ext}" stored_path = DATASET_UPLOAD_DIR / stored_name # Stream file to disk in chunks to avoid holding entire file in memory with open(stored_path, "wb") as f: while chunk := await file.read(1024 * 1024): f.write(chunk) if stored_path.stat().st_size == 0: stored_path.unlink(missing_ok = True) raise HTTPException(status_code = 400, detail = "Empty upload payload") return UploadDatasetResponse(filename = filename, stored_path = str(stored_path)) @router.get("/local", response_model = LocalDatasetsResponse) def list_local_datasets( current_subject: str = Depends(get_current_subject), ) -> LocalDatasetsResponse: return LocalDatasetsResponse(datasets = _build_local_dataset_items()) @router.post("/check-format", response_model = CheckFormatResponse) def check_format( request: CheckFormatRequest, current_subject: str = Depends(get_current_subject), ): """ Check if a dataset requires manual column mapping. Strategy for HuggingFace datasets: 1. list_repo_files → pick the first data file → load_dataset(data_files=[…]) Avoids resolving thousands of files; typically ~2-4 s. 2. Full streaming load_dataset as a last-resort fallback. Local files are loaded directly. Using a plain `def` (not async) so FastAPI runs this in a thread-pool, preventing any blocking IO from freezing the event loop. 
""" try: from itertools import islice from datasets import Dataset, load_dataset from utils.datasets import format_dataset PREVIEW_SIZE = 10 logger.info(f"Checking format for dataset: {request.dataset_name}") dataset_path = resolve_dataset_path(request.dataset_name) total_rows = None if dataset_path.exists(): # ── Local file ────────────────────────────────────────── train_split = request.train_split or "train" preview_slice, total_rows = _load_local_preview_slice( dataset_path = dataset_path, train_split = train_split, preview_size = PREVIEW_SIZE, ) else: # ── HuggingFace dataset ───────────────────────────────── # Tier 1: list_repo_files → load only the first data file preview_slice = None try: from huggingface_hub import HfApi api = HfApi() repo_files = api.list_repo_files( request.dataset_name, repo_type = "dataset", token = request.hf_token or None, ) data_files = [ f for f in repo_files if any(f.endswith(ext) for ext in DATA_EXTS) ] # Prefer tabular formats over archives (e.g. images.zip → ImageFolder # with synthetic image/label columns that don't match the real schema). tabular_files = [ f for f in data_files if any(f.endswith(ext) for ext in _TABULAR_EXTS) ] candidates = tabular_files or data_files # When a subset is specified, narrow to files whose name matches # (e.g. subset="testmini" → prefer "testmini.parquet"). 
if request.subset and candidates: subset_matches = [ f for f in candidates if request.subset in Path(f).stem ] if subset_matches: candidates = subset_matches if candidates: first_file = candidates[0] logger.info(f"Tier 1: loading single file {first_file}") load_kwargs = { "path": request.dataset_name, "data_files": [first_file], "split": "train", "streaming": True, } if request.hf_token: load_kwargs["token"] = request.hf_token streamed_ds = load_dataset(**load_kwargs) rows = list(islice(streamed_ds, PREVIEW_SIZE)) if rows: preview_slice = Dataset.from_list(rows) except Exception as e: logger.warning(f"Tier 1 (single-file) failed: {e}") if preview_slice is None: # Tier 2: full streaming (resolves all files — slow for large repos) logger.info("Tier 2: falling back to full streaming load_dataset") load_kwargs = { "path": request.dataset_name, "split": request.train_split, "streaming": True, } if request.subset: load_kwargs["name"] = request.subset if request.hf_token: load_kwargs["token"] = request.hf_token streamed_ds = load_dataset(**load_kwargs) rows = list(islice(streamed_ds, PREVIEW_SIZE)) if not rows: raise HTTPException( status_code = 400, detail = "Dataset appears to be empty or could not be streamed", ) preview_slice = Dataset.from_list(rows) total_rows = None # Run lightweight format check on the preview slice result = check_dataset_format(preview_slice, is_vlm = request.is_vlm) logger.info( f"Format check result: requires_mapping={result['requires_manual_mapping']}, format={result['detected_format']}, is_image={result.get('is_image', False)}" ) # Generate preview samples preview_samples = None if not result["requires_manual_mapping"]: if result.get("suggested_mapping"): # Heuristic-detected: show raw data so columns match the API response. # Processing (column stripping) happens at training time, not preview. 
preview_samples = _serialize_preview_rows(preview_slice) else: try: format_result = format_dataset( preview_slice, format_type = "auto", num_proc = 1, # Only 10 preview rows — no need for multiprocessing ) processed = format_result["dataset"] preview_samples = _serialize_preview_rows(processed) except Exception as e: logger.warning( f"Processed preview generation failed (non-fatal): {e}" ) preview_samples = _serialize_preview_rows(preview_slice) else: preview_samples = _serialize_preview_rows(preview_slice) # Collect warnings: from check_dataset_format + URL-based image detection warning = result.get("warning") image_col = result.get("detected_image_column") if image_col and image_col in (result.get("columns") or []): try: sample_val = preview_slice[0][image_col] if isinstance(sample_val, str) and sample_val.startswith( ("http://", "https://") ): url_warning = ( "This dataset contains image URLs instead of embedded images. " "Images will be downloaded during training, which may be slow for large datasets." 
) logger.info(f"URL-based image column detected: {image_col}") warning = f"{warning} {url_warning}" if warning else url_warning except Exception: pass return CheckFormatResponse( requires_manual_mapping = result["requires_manual_mapping"], detected_format = result["detected_format"], columns = result["columns"], is_image = result.get("is_image", False), is_audio = result.get("is_audio", False), multimodal_columns = result.get("multimodal_columns"), suggested_mapping = result.get("suggested_mapping"), detected_image_column = result.get("detected_image_column"), detected_audio_column = result.get("detected_audio_column"), detected_text_column = result.get("detected_text_column"), detected_speaker_column = result.get("detected_speaker_column"), preview_samples = preview_samples, total_rows = total_rows, warning = warning, ) except HTTPException: raise except Exception as e: logger.error(f"Error checking dataset format: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to check dataset format: {str(e)}" ) @router.post("/ai-assist-mapping", response_model = AiAssistMappingResponse) def ai_assist_mapping( request: AiAssistMappingRequest, current_subject: str = Depends(get_current_subject), ): """ Run LLM-assisted dataset conversion advisor (user-triggered). Multi-pass analysis using a 7B helper model: Pass 1: Classify dataset type from HF card + samples Pass 2: Generate conversion strategy (system prompt, templates) Pass 3: Validate conversion quality Falls back to simple column classification if the advisor fails. 
""" try: from utils.datasets.llm_assist import llm_conversion_advisor # Truncate sample values for the LLM prompt truncated = [ {col: str(s.get(col, ""))[:200] for col in request.columns} for s in request.samples[:5] ] result = llm_conversion_advisor( column_names = request.columns, samples = truncated, dataset_name = request.dataset_name, hf_token = request.hf_token, model_name = request.model_name, model_type = request.model_type, ) if result and result.get("success"): return AiAssistMappingResponse( success = True, suggested_mapping = result.get("suggested_mapping"), system_prompt = result.get("system_prompt"), user_template = result.get("user_template"), assistant_template = result.get("assistant_template"), label_mapping = result.get("label_mapping"), dataset_type = result.get("dataset_type"), is_conversational = result.get("is_conversational"), user_notification = result.get("user_notification"), ) return AiAssistMappingResponse( success = False, warning = "AI could not determine column roles. Please assign them manually.", ) except Exception as e: logger.error(f"AI assist mapping failed: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = f"AI assist failed: {str(e)}") ================================================ FILE: studio/backend/routes/export.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Export API routes: checkpoint discovery and model export operations. 
""" import sys from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Query import structlog from loggers import get_logger # Add backend directory to path backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) # Auth from auth.authentication import get_current_subject # Import backend functions try: from core.export import get_export_backend except ImportError: parent_backend = backend_path.parent / "backend" if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.export import get_export_backend # Import Pydantic models from models import ( LoadCheckpointRequest, ExportStatusResponse, ExportOperationResponse, ExportMergedModelRequest, ExportBaseModelRequest, ExportGGUFRequest, ExportLoRAAdapterRequest, ) router = APIRouter() logger = get_logger(__name__) @router.post("/load-checkpoint", response_model = ExportOperationResponse) async def load_checkpoint( request: LoadCheckpointRequest, current_subject: str = Depends(get_current_subject), ): """ Load a checkpoint into the export backend. Wraps ExportBackend.load_checkpoint. """ try: # Version switching is handled automatically by the subprocess-based # export backend — no need for ensure_transformers_version() here. # Free GPU memory: shut down any running inference/training subprocesses # before loading the export checkpoint (they'd compete for VRAM). 
try: from core.inference import get_inference_backend inf = get_inference_backend() if inf.active_model_name: logger.info( "Unloading inference model '%s' to free GPU memory for export", inf.active_model_name, ) inf._shutdown_subprocess() inf.active_model_name = None inf.models.clear() except Exception as e: logger.warning("Could not unload inference model: %s", e) try: from core.training import get_training_backend trn = get_training_backend() if trn.is_training_active(): logger.info("Stopping active training to free GPU memory for export") trn.stop_training() # Wait for training subprocess to actually exit before proceeding, # otherwise it may still hold GPU memory when export tries to load. for _ in range(60): # up to 30s if not trn.is_training_active(): break import time time.sleep(0.5) else: logger.warning( "Training subprocess did not exit within 30s, proceeding anyway" ) except Exception as e: logger.warning("Could not stop training: %s", e) backend = get_export_backend() success, message = backend.load_checkpoint( checkpoint_path = request.checkpoint_path, max_seq_length = request.max_seq_length, load_in_4bit = request.load_in_4bit, trust_remote_code = request.trust_remote_code, ) if not success: raise HTTPException(status_code = 400, detail = message) return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: logger.error(f"Error loading checkpoint: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to load checkpoint: {str(e)}", ) @router.post("/cleanup", response_model = ExportOperationResponse) async def cleanup_export_memory( current_subject: str = Depends(get_current_subject), ): """ Cleanup export-related models from memory (GPU/CPU). Wraps ExportBackend.cleanup_memory. """ try: backend = get_export_backend() success = backend.cleanup_memory() if not success: raise HTTPException( status_code = 500, detail = "Memory cleanup failed. 
See server logs for details.", ) return ExportOperationResponse( success = True, message = "Memory cleanup completed successfully", ) except HTTPException: raise except Exception as e: logger.error(f"Error during export memory cleanup: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to cleanup export memory: {str(e)}", ) @router.get("/status", response_model = ExportStatusResponse) async def get_export_status( current_subject: str = Depends(get_current_subject), ): """ Get current export backend status (loaded checkpoint, model type, PEFT flag). """ try: backend = get_export_backend() return ExportStatusResponse( current_checkpoint = backend.current_checkpoint, is_vision = bool(getattr(backend, "is_vision", False)), is_peft = bool(getattr(backend, "is_peft", False)), ) except Exception as e: logger.error(f"Error getting export status: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to get export status: {str(e)}", ) @router.post("/export/merged", response_model = ExportOperationResponse) async def export_merged_model( request: ExportMergedModelRequest, current_subject: str = Depends(get_current_subject), ): """ Export a merged PEFT model (e.g., 16-bit or 4-bit) and optionally push to Hub. Wraps ExportBackend.export_merged_model. 
""" try: backend = get_export_backend() success, message = backend.export_merged_model( save_directory = request.save_directory, format_type = request.format_type, push_to_hub = request.push_to_hub, repo_id = request.repo_id, hf_token = request.hf_token, private = request.private, ) if not success: raise HTTPException(status_code = 400, detail = message) return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: logger.error(f"Error exporting merged model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to export merged model: {str(e)}", ) @router.post("/export/base", response_model = ExportOperationResponse) async def export_base_model( request: ExportBaseModelRequest, current_subject: str = Depends(get_current_subject), ): """ Export a non-PEFT base model and optionally push to Hub. Wraps ExportBackend.export_base_model. """ try: backend = get_export_backend() success, message = backend.export_base_model( save_directory = request.save_directory, push_to_hub = request.push_to_hub, repo_id = request.repo_id, hf_token = request.hf_token, private = request.private, base_model_id = request.base_model_id, ) if not success: raise HTTPException(status_code = 400, detail = message) return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: logger.error(f"Error exporting base model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to export base model: {str(e)}", ) @router.post("/export/gguf", response_model = ExportOperationResponse) async def export_gguf( request: ExportGGUFRequest, current_subject: str = Depends(get_current_subject), ): """ Export the current model to GGUF format and optionally push to Hub. Wraps ExportBackend.export_gguf. 
""" try: backend = get_export_backend() success, message = backend.export_gguf( save_directory = request.save_directory, quantization_method = request.quantization_method, push_to_hub = request.push_to_hub, repo_id = request.repo_id, hf_token = request.hf_token, ) if not success: raise HTTPException(status_code = 400, detail = message) return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: logger.error(f"Error exporting GGUF model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to export GGUF model: {str(e)}", ) @router.post("/export/lora", response_model = ExportOperationResponse) async def export_lora_adapter( request: ExportLoRAAdapterRequest, current_subject: str = Depends(get_current_subject), ): """ Export only the LoRA adapter (if the loaded model is PEFT). Wraps ExportBackend.export_lora_adapter. """ try: backend = get_export_backend() success, message = backend.export_lora_adapter( save_directory = request.save_directory, push_to_hub = request.push_to_hub, repo_id = request.repo_id, hf_token = request.hf_token, private = request.private, ) if not success: raise HTTPException(status_code = 400, detail = message) return ExportOperationResponse(success = True, message = message) except HTTPException: raise except Exception as e: logger.error(f"Error exporting LoRA adapter: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to export LoRA adapter: {str(e)}", ) ================================================ FILE: studio/backend/routes/inference.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference API routes for model loading and text generation. 
""" import sys import time import uuid from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse, JSONResponse from typing import Optional import json import structlog from loggers import get_logger import asyncio import threading # Add backend directory to path backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) # Import backend functions try: from core.inference import get_inference_backend from core.inference.llama_cpp import LlamaCppBackend from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults except ImportError: parent_backend = backend_path.parent / "backend" if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.inference import get_inference_backend from core.inference.llama_cpp import LlamaCppBackend from utils.models import ModelConfig from utils.inference import load_inference_config from utils.models.model_config import load_model_defaults from models.inference import ( LoadRequest, UnloadRequest, GenerateRequest, LoadResponse, UnloadResponse, InferenceStatusResponse, ChatCompletionRequest, ChatCompletionChunk, ChatCompletion, ChunkChoice, ChoiceDelta, CompletionChoice, CompletionMessage, ValidateModelRequest, ValidateModelResponse, ) from auth.authentication import get_current_subject import io import wave import base64 import numpy as np router = APIRouter() logger = get_logger(__name__) # GGUF inference backend (llama-server) _llama_cpp_backend = LlamaCppBackend() def get_llama_cpp_backend() -> LlamaCppBackend: return _llama_cpp_backend @router.post("/load", response_model = LoadResponse) async def load_model( request: LoadRequest, current_subject: str = Depends(get_current_subject), ): """ Load a model for inference. 
The model_path should be a clean identifier from GET /models/list. Returns inference configuration parameters (temperature, top_p, top_k, min_p) from the model's YAML config, falling back to default.yaml for missing values. GGUF models are loaded via llama-server (llama.cpp) instead of Unsloth. """ try: # Version switching is handled automatically by the subprocess-based # inference backend — no need for ensure_transformers_version() here. # ── Already-loaded check: skip reload if the exact model is active ── backend = get_inference_backend() llama_backend = get_llama_cpp_backend() if request.gguf_variant: if ( llama_backend.is_loaded and llama_backend.hf_variant and llama_backend.hf_variant.lower() == request.gguf_variant.lower() and llama_backend.model_identifier and llama_backend.model_identifier.lower() == request.model_path.lower() ): logger.info( f"Model already loaded (GGUF): {request.model_path} variant={request.gguf_variant}, skipping reload" ) inference_config = load_inference_config(llama_backend.model_identifier) from utils.models import is_audio_input_type _gguf_audio = ( llama_backend._audio_type if hasattr(llama_backend, "_audio_type") else None ) _gguf_is_audio = getattr(llama_backend, "_is_audio", False) return LoadResponse( status = "already_loaded", model = llama_backend.model_identifier, display_name = llama_backend.model_identifier, is_vision = llama_backend._is_vision, is_lora = False, is_gguf = True, is_audio = _gguf_is_audio, audio_type = _gguf_audio, has_audio_input = is_audio_input_type(_gguf_audio) if _gguf_audio else False, inference = inference_config, context_length = llama_backend.context_length, supports_reasoning = llama_backend.supports_reasoning, chat_template = llama_backend.chat_template, ) else: if ( backend.active_model_name and backend.active_model_name.lower() == request.model_path.lower() ): logger.info( f"Model already loaded (Unsloth): {request.model_path}, skipping reload" ) inference_config = 
load_inference_config(backend.active_model_name) _model_info = backend.models.get(backend.active_model_name, {}) _chat_template = None try: _tpl_info = _model_info.get("chat_template_info", {}) _chat_template = _tpl_info.get("template") except Exception as e: logger.warning( f"Could not retrieve chat template for {backend.active_model_name}: {e}" ) return LoadResponse( status = "already_loaded", model = backend.active_model_name, display_name = backend.active_model_name, is_vision = _model_info.get("is_vision", False), is_lora = _model_info.get("is_lora", False), is_gguf = False, is_audio = _model_info.get("is_audio", False), audio_type = _model_info.get("audio_type"), has_audio_input = _model_info.get("has_audio_input", False), inference = inference_config, chat_template = _chat_template, ) # Create config using clean factory method # is_lora is auto-detected from adapter_config.json on disk/HF config = ModelConfig.from_identifier( model_id = request.model_path, hf_token = request.hf_token, gguf_variant = request.gguf_variant, ) if not config: raise HTTPException( status_code = 400, detail = f"Invalid model identifier: {request.model_path}", ) # ── GGUF path: load via llama-server ────────────────────── if config.is_gguf: llama_backend = get_llama_cpp_backend() unsloth_backend = get_inference_backend() # Unload any active Unsloth model first to free VRAM if unsloth_backend.active_model_name: logger.info( f"Unloading Unsloth model '{unsloth_backend.active_model_name}' before loading GGUF" ) unsloth_backend.unload_model(unsloth_backend.active_model_name) # Route to HF mode or local mode based on config # Run in a thread so the event loop stays free for progress # polling and other requests during the (potentially long) # GGUF download + llama-server startup. 
if config.gguf_hf_repo: # HF mode: download via huggingface_hub then start llama-server success = await asyncio.to_thread( llama_backend.load_model, hf_repo = config.gguf_hf_repo, hf_variant = config.gguf_variant, hf_token = request.hf_token, model_identifier = config.identifier, is_vision = config.is_vision, n_ctx = request.max_seq_length, chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, ) else: # Local mode: llama-server loads via -m success = await asyncio.to_thread( llama_backend.load_model, gguf_path = config.gguf_file, mmproj_path = config.gguf_mmproj_file, model_identifier = config.identifier, is_vision = config.is_vision, n_ctx = request.max_seq_length, chat_template_override = request.chat_template_override, cache_type_kv = request.cache_type_kv, ) if not success: raise HTTPException( status_code = 500, detail = f"Failed to load GGUF model: {config.display_name}", ) logger.info(f"Loaded GGUF model via llama-server: {config.identifier}") # Detect TTS audio by probing the loaded model's vocabulary from utils.models import is_audio_input_type _gguf_audio = llama_backend.detect_audio_type() _gguf_is_audio = _gguf_audio in ("snac", "bicodec", "dac") llama_backend._is_audio = _gguf_is_audio llama_backend._audio_type = _gguf_audio if _gguf_is_audio: logger.info(f"GGUF model detected as audio: audio_type={_gguf_audio}") await asyncio.to_thread(llama_backend.init_audio_codec, _gguf_audio) inference_config = load_inference_config(config.identifier) return LoadResponse( status = "loaded", model = config.identifier, display_name = config.display_name, is_vision = config.is_vision, is_lora = False, is_gguf = True, is_audio = _gguf_is_audio, audio_type = _gguf_audio, has_audio_input = is_audio_input_type(_gguf_audio), inference = inference_config, context_length = llama_backend.context_length, supports_reasoning = llama_backend.supports_reasoning, supports_tools = llama_backend.supports_tools, cache_type_kv = 
llama_backend.cache_type_kv, chat_template = llama_backend.chat_template, ) # ── Standard path: load via Unsloth/transformers ────────── backend = get_inference_backend() # Unload any active GGUF model first llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded: logger.info("Unloading GGUF model before loading Unsloth model") llama_backend.unload_model() # Shut down any export subprocess to free VRAM try: from core.export import get_export_backend exp_backend = get_export_backend() if exp_backend.current_checkpoint: logger.info( "Shutting down export subprocess to free GPU memory for inference" ) exp_backend._shutdown_subprocess() exp_backend.current_checkpoint = None exp_backend.is_vision = False exp_backend.is_peft = False except Exception as e: logger.warning("Could not shut down export subprocess: %s", e) # Auto-detect quantization for LoRA adapters from adapter_config.json # The training pipeline patches this file with "unsloth_training_method" # which is 'qlora' or 'lora'. Only LoRA (16-bit) needs load_in_4bit=False. 
load_in_4bit = request.load_in_4bit if config.is_lora and config.path: import json from pathlib import Path adapter_cfg_path = Path(config.path) / "adapter_config.json" if adapter_cfg_path.exists(): try: with open(adapter_cfg_path) as f: adapter_cfg = json.load(f) training_method = adapter_cfg.get("unsloth_training_method") if training_method == "lora" and load_in_4bit: logger.info( f"adapter_config.json says unsloth_training_method='lora' — " f"setting load_in_4bit=False to match 16-bit training" ) load_in_4bit = False elif training_method == "qlora" and not load_in_4bit: logger.info( f"adapter_config.json says unsloth_training_method='qlora' — " f"setting load_in_4bit=True to match QLoRA training" ) load_in_4bit = True elif training_method: logger.info( f"Training method: {training_method}, load_in_4bit={load_in_4bit}" ) else: # No unsloth_training_method — fallback to base model name if ( config.base_model and "-bnb-4bit" not in config.base_model.lower() and load_in_4bit ): logger.info( f"No unsloth_training_method in adapter_config.json. " f"Base model '{config.base_model}' has no -bnb-4bit suffix — " f"setting load_in_4bit=False" ) load_in_4bit = False except Exception as e: logger.warning(f"Could not read adapter_config.json: {e}") # Load the model in a thread so the event loop stays free # for download progress polling and other requests. success = await asyncio.to_thread( backend.load_model, config = config, max_seq_length = request.max_seq_length, load_in_4bit = load_in_4bit, hf_token = request.hf_token, trust_remote_code = request.trust_remote_code, ) if not success: # Check if YAML says this model needs trust_remote_code if not request.trust_remote_code: model_defaults = load_model_defaults(config.identifier) yaml_trust = model_defaults.get("inference", {}).get( "trust_remote_code", False ) if yaml_trust: raise HTTPException( status_code = 400, detail = ( f"Model '{config.display_name}' requires trust_remote_code to be enabled. 
" f"Please enable 'Trust remote code' in Chat Settings and try again." ), ) raise HTTPException( status_code = 500, detail = f"Failed to load model: {config.display_name}" ) logger.info(f"Loaded model: {config.identifier}") # Load inference configuration parameters inference_config = load_inference_config(config.identifier) # Get chat template from tokenizer _chat_template = None try: _model_info = backend.models.get(config.identifier, {}) _tpl_info = _model_info.get("chat_template_info", {}) _chat_template = _tpl_info.get("template") except Exception: pass return LoadResponse( status = "loaded", model = config.identifier, display_name = config.display_name, is_vision = config.is_vision, is_lora = config.is_lora, is_gguf = False, is_audio = config.is_audio, audio_type = config.audio_type, has_audio_input = config.has_audio_input, inference = inference_config, chat_template = _chat_template, ) except HTTPException: raise except Exception as e: logger.error(f"Error loading model: {e}", exc_info = True) msg = str(e) # Surface a friendlier message for models that Unsloth cannot load not_supported_hints = [ "No config file found", "not yet supported", "is not supported", "does not support", ] if any(h.lower() in msg.lower() for h in not_supported_hints): msg = f"This model is not supported yet. Try a different model. (Original error: {msg})" raise HTTPException(status_code = 500, detail = f"Failed to load model: {msg}") @router.post("/validate", response_model = ValidateModelResponse) async def validate_model( request: ValidateModelRequest, current_subject: str = Depends(get_current_subject), ): """ Lightweight validation endpoint for model identifiers. This checks that ModelConfig.from_identifier() can resolve the given model_path, but it does NOT actually load model weights into GPU memory. 
""" try: config = ModelConfig.from_identifier( model_id = request.model_path, hf_token = request.hf_token, gguf_variant = request.gguf_variant, ) if not config: raise HTTPException( status_code = 400, detail = f"Invalid model identifier: {request.model_path}", ) return ValidateModelResponse( valid = True, message = "Model identifier is valid.", identifier = config.identifier, display_name = getattr(config, "display_name", config.identifier), is_gguf = getattr(config, "is_gguf", False), is_lora = getattr(config, "is_lora", False), is_vision = getattr(config, "is_vision", False), ) except HTTPException: raise except Exception as e: logger.error( f"Error validating model identifier '{request.model_path}': {e}", exc_info = True, ) raise HTTPException( status_code = 400, detail = f"Invalid model: {str(e)}", ) @router.post("/unload", response_model = UnloadResponse) async def unload_model( request: UnloadRequest, current_subject: str = Depends(get_current_subject), ): """ Unload a model from memory. Routes to the correct backend (llama-server for GGUF, Unsloth otherwise). 
""" try: # Check if the GGUF backend has this model loaded or is loading it llama_backend = get_llama_cpp_backend() if llama_backend.is_active and ( llama_backend.model_identifier == request.model_path or not llama_backend.is_loaded ): llama_backend.unload_model() logger.info(f"Unloaded GGUF model: {request.model_path}") return UnloadResponse(status = "unloaded", model = request.model_path) # Otherwise, unload from Unsloth backend backend = get_inference_backend() backend.unload_model(request.model_path) logger.info(f"Unloaded model: {request.model_path}") return UnloadResponse(status = "unloaded", model = request.model_path) except Exception as e: logger.error(f"Error unloading model: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = f"Failed to unload model: {str(e)}") @router.post("/generate/stream") async def generate_stream( request: GenerateRequest, current_subject: str = Depends(get_current_subject), ): """ Generate a chat response with Server-Sent Events (SSE) streaming. For vision models, provide image_base64 with the base64-encoded image. """ backend = get_inference_backend() if not backend.active_model_name: raise HTTPException( status_code = 400, detail = "No model loaded. Call POST /inference/load first." ) # Decode image if provided (for vision models) image = None if request.image_base64: try: import base64 from PIL import Image from io import BytesIO # Check if current model supports vision model_info = backend.models.get(backend.active_model_name, {}) if not model_info.get("is_vision"): raise HTTPException( status_code = 400, detail = "Image provided but current model is text-only. 
Load a vision model.", ) image_data = base64.b64decode(request.image_base64) image = Image.open(BytesIO(image_data)) image = backend.resize_image(image) except HTTPException: raise except Exception as e: raise HTTPException( status_code = 400, detail = f"Failed to decode image: {str(e)}" ) async def stream(): try: for chunk in backend.generate_chat_response( messages = request.messages, system_prompt = request.system_prompt, image = image, temperature = request.temperature, top_p = request.top_p, top_k = request.top_k, max_new_tokens = request.max_new_tokens, repetition_penalty = request.repetition_penalty, ): yield f"data: {json.dumps({'content': chunk})}\n\n" yield "data: [DONE]\n\n" except Exception as e: backend.reset_generation_state() logger.error(f"Error during generation: {e}", exc_info = True) yield f"data: {json.dumps({'error': 'An internal error occurred'})}\n\n" return StreamingResponse( stream(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", }, ) @router.get("/status", response_model = InferenceStatusResponse) async def get_status( current_subject: str = Depends(get_current_subject), ): """ Get current inference backend status. Reports whichever backend (Unsloth or llama-server) is currently active. 
""" try: llama_backend = get_llama_cpp_backend() # If a GGUF model is loaded via llama-server, report that if llama_backend.is_loaded: _model_id = llama_backend.model_identifier _inference_cfg = load_inference_config(_model_id) if _model_id else None return InferenceStatusResponse( active_model = _model_id, is_vision = llama_backend.is_vision, is_gguf = True, gguf_variant = llama_backend.hf_variant, is_audio = getattr(llama_backend, "_is_audio", False), audio_type = getattr(llama_backend, "_audio_type", None), loading = [], loaded = [_model_id], inference = _inference_cfg, supports_reasoning = llama_backend.supports_reasoning, supports_tools = llama_backend.supports_tools, context_length = llama_backend.context_length, ) # Otherwise, report Unsloth backend status backend = get_inference_backend() is_vision = False is_audio = False audio_type = None has_audio_input = False if backend.active_model_name: model_info = backend.models.get(backend.active_model_name, {}) is_vision = model_info.get("is_vision", False) is_audio = model_info.get("is_audio", False) audio_type = model_info.get("audio_type") has_audio_input = model_info.get("has_audio_input", False) # gpt-oss safetensors models support reasoning via harmony channels supports_reasoning = False if backend.active_model_name and hasattr(backend, "_is_gpt_oss_model"): supports_reasoning = backend._is_gpt_oss_model() return InferenceStatusResponse( active_model = backend.active_model_name, is_vision = is_vision, is_gguf = False, is_audio = is_audio, audio_type = audio_type, has_audio_input = has_audio_input, loading = list(getattr(backend, "loading_models", set())), loaded = list(backend.models.keys()), supports_reasoning = supports_reasoning, ) except Exception as e: logger.error(f"Error getting status: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = f"Failed to get status: {str(e)}") # ===================================================================== # Audio (TTS) Generation 
(/audio/generate) # ===================================================================== @router.post("/audio/generate") async def generate_audio( payload: ChatCompletionRequest, request: Request, current_subject: str = Depends(get_current_subject), ): """ Generate audio (TTS) from the latest user message. Returns a JSON response with base64-encoded WAV audio. Works with both GGUF (llama-server) and Unsloth/transformers backends. """ import base64 # Extract text from the last user message _, chat_messages, _ = _extract_content_parts(payload.messages) if not chat_messages: raise HTTPException(status_code = 400, detail = "No messages provided.") last_user_msg = next( (m for m in reversed(chat_messages) if m["role"] == "user"), None ) if not last_user_msg: raise HTTPException(status_code = 400, detail = "No user message found.") text = last_user_msg["content"] # Pick backend — both return (wav_bytes, sample_rate) llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and getattr(llama_backend, "_is_audio", False): model_name = llama_backend.model_identifier gen = lambda: llama_backend.generate_audio_response( text = text, audio_type = llama_backend._audio_type, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_new_tokens = payload.max_tokens or 2048, repetition_penalty = payload.repetition_penalty, ) else: backend = get_inference_backend() if not backend.active_model_name: raise HTTPException(status_code = 400, detail = "No model loaded.") model_info = backend.models.get(backend.active_model_name, {}) if not model_info.get("is_audio"): raise HTTPException( status_code = 400, detail = "Active model is not an audio model." 
) model_name = backend.active_model_name gen = lambda: backend.generate_audio_response( text = text, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_new_tokens = payload.max_tokens or 2048, repetition_penalty = payload.repetition_penalty, use_adapter = payload.use_adapter, ) try: wav_bytes, sample_rate = await asyncio.get_event_loop().run_in_executor( None, gen ) except Exception as e: logger.error(f"Audio generation error: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = str(e)) audio_b64 = base64.b64encode(wav_bytes).decode("ascii") return JSONResponse( content = { "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", "object": "chat.completion.audio", "model": model_name, "audio": {"data": audio_b64, "format": "wav", "sample_rate": sample_rate}, "choices": [ { "index": 0, "message": { "role": "assistant", "content": f'[Generated audio from: "{text[:100]}"]', }, "finish_reason": "stop", } ], } ) # ===================================================================== # OpenAI-Compatible Chat Completions (/chat/completions) # ===================================================================== def _decode_audio_base64(b64: str) -> np.ndarray: """Decode base64 audio (any format) → float32 numpy array at 16kHz.""" import torch import torchaudio import tempfile import os from utils.paths import ensure_dir, tmp_root raw = base64.b64decode(b64) # torchaudio.load needs a file path or file-like object with format hint # Write to a temp file so torchaudio can auto-detect the format with tempfile.NamedTemporaryFile( suffix = ".audio", delete = False, dir = str(ensure_dir(tmp_root())), ) as tmp: tmp.write(raw) tmp_path = tmp.name try: waveform, sr = torchaudio.load(tmp_path) finally: os.unlink(tmp_path) # Convert to mono if stereo if waveform.shape[0] > 1: waveform = waveform.mean(dim = 0, keepdim = True) # Resample to 16kHz if needed if sr != 16000: resampler = torchaudio.transforms.Resample(orig_freq = 
sr, new_freq = 16000) waveform = resampler(waveform) return waveform.squeeze(0).numpy() def _extract_content_parts( messages: list, ) -> tuple[str, list[dict], "Optional[str]"]: """ Parse OpenAI-format messages into components the inference backend expects. Handles both plain-string ``content`` and multimodal content-part arrays (``[{type: "text", ...}, {type: "image_url", ...}]``). Returns: system_prompt: The system message text (empty string if none provided). chat_messages: Non-system messages with content flattened to strings. image_base64: Base64 data of the *first* image found, or ``None``. """ system_prompt = "" chat_messages: list[dict] = [] first_image_b64: Optional[str] = None for msg in messages: # ── System messages → extract as system_prompt ──────── if msg.role == "system": if isinstance(msg.content, str): system_prompt = msg.content elif isinstance(msg.content, list): # Unlikely but handle: join text parts system_prompt = "\n".join( p.text for p in msg.content if p.type == "text" ) continue # ── User / assistant messages ───────────────────────── if isinstance(msg.content, str): # Plain string content — pass through chat_messages.append({"role": msg.role, "content": msg.content}) elif isinstance(msg.content, list): # Multimodal content parts text_parts: list[str] = [] for part in msg.content: if part.type == "text": text_parts.append(part.text) elif part.type == "image_url" and first_image_b64 is None: url = part.image_url.url if url.startswith("data:"): # data:image/png;base64, → extract first_image_b64 = url.split(",", 1)[1] if "," in url else None else: logger.warning( f"Remote image URLs not yet supported: {url[:80]}..." 
) combined_text = "\n".join(text_parts) if text_parts else "" chat_messages.append({"role": msg.role, "content": combined_text}) return system_prompt, chat_messages, first_image_b64 @router.post("/chat/completions") async def openai_chat_completions( payload: ChatCompletionRequest, request: Request, current_subject: str = Depends(get_current_subject), ): """ OpenAI-compatible chat completions endpoint. Supports multimodal messages: ``content`` may be a plain string or a list of content parts (``text`` / ``image_url``). Streaming (default): returns SSE chunks matching OpenAI's format. Non-streaming: returns a single ChatCompletion JSON object. Automatically routes to the correct backend: - GGUF models → llama-server via LlamaCppBackend - Other models → Unsloth/transformers via InferenceBackend """ llama_backend = get_llama_cpp_backend() using_gguf = llama_backend.is_loaded # ── Determine which backend is active ───────────────────── if using_gguf: model_name = llama_backend.model_identifier or payload.model if getattr(llama_backend, "_is_audio", False): return await generate_audio(payload, request) else: backend = get_inference_backend() if not backend.active_model_name: raise HTTPException( status_code = 400, detail = "No model loaded. Call POST /inference/load first.", ) model_name = backend.active_model_name or payload.model # ── Audio TTS path: auto-route to audio generation ──── # (Whisper is ASR not TTS — handled below in audio input path) model_info = backend.models.get(backend.active_model_name, {}) if model_info.get("is_audio") and model_info.get("audio_type") != "whisper": return await generate_audio(payload, request) # ── Whisper without audio: return clear error ── if model_info.get("audio_type") == "whisper" and not payload.audio_base64: raise HTTPException( status_code = 400, detail = "Whisper models require audio input. 
Please upload an audio file.", ) # ── Audio INPUT path: decode WAV and route to audio input generation ── if payload.audio_base64 and model_info.get("has_audio_input"): audio_array = _decode_audio_base64(payload.audio_base64) system_prompt, chat_messages, _ = _extract_content_parts(payload.messages) cancel_event = threading.Event() completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" created = int(time.time()) def audio_input_generate(): if model_info.get("audio_type") == "whisper": return backend.generate_whisper_response( audio_array = audio_array, cancel_event = cancel_event, ) return backend.generate_audio_input_response( messages = chat_messages, system_prompt = system_prompt, audio_array = audio_array, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_new_tokens = payload.max_tokens or 2048, repetition_penalty = payload.repetition_penalty, cancel_event = cancel_event, ) if payload.stream: async def audio_input_stream(): try: first_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(role = "assistant"), finish_reason = None, ) ], ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" for chunk_text in audio_input_generate(): if await request.is_disconnected(): cancel_event.set() return if chunk_text: chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(content = chunk_text), finish_reason = None, ) ], ) yield f"data: {chunk.model_dump_json(exclude_none = True)}\n\n" final_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice(delta = ChoiceDelta(), finish_reason = "stop") ], ) yield f"data: {final_chunk.model_dump_json(exclude_none = True)}\n\n" yield "data: [DONE]\n\n" except asyncio.CancelledError: cancel_event.set() raise except Exception as e: logger.error( f"Error 
during audio input streaming: {e}", exc_info = True ) yield f"data: {json.dumps({'error': {'message': 'An internal error occurred', 'type': 'server_error'}})}\n\n" return StreamingResponse( audio_input_stream(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) else: full_text = "".join(audio_input_generate()) response = ChatCompletion( id = completion_id, created = created, model = model_name, choices = [ CompletionChoice( message = CompletionMessage(content = full_text), finish_reason = "stop", ) ], ) return JSONResponse(content = response.model_dump()) # ── Parse messages (handles multimodal content parts) ───── system_prompt, chat_messages, extracted_image_b64 = _extract_content_parts( payload.messages ) if not chat_messages: raise HTTPException( status_code = 400, detail = "At least one non-system message is required.", ) # ── GGUF path: proxy to llama-server /v1/chat/completions ── if using_gguf: # Reject images if this GGUF model doesn't support vision image_b64 = extracted_image_b64 or payload.image_base64 if image_b64 and not llama_backend.is_vision: raise HTTPException( status_code = 400, detail = "Image provided but current GGUF model does not support vision.", ) # Convert image to PNG for llama-server (stb_image has limited format support) if image_b64: try: import base64 as _b64 from io import BytesIO as _BytesIO from PIL import Image as _Image raw = _b64.b64decode(image_b64) img = _Image.open(_BytesIO(raw)) if img.mode == "RGBA": img = img.convert("RGB") buf = _BytesIO() img.save(buf, format = "PNG") image_b64 = _b64.b64encode(buf.getvalue()).decode("ascii") except Exception as e: raise HTTPException( status_code = 400, detail = f"Failed to process image: {e}" ) # Build message list with system prompt prepended gguf_messages = [] if system_prompt: gguf_messages.append({"role": "system", "content": system_prompt}) gguf_messages.extend(chat_messages) cancel_event = 
threading.Event() completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" created = int(time.time()) # ── Tool-calling path (agentic loop) ────────────────── use_tools = ( payload.enable_tools and llama_backend.supports_tools and not image_b64 ) if use_tools: from core.inference.tools import ALL_TOOLS if payload.enabled_tools is not None: tools_to_use = [ t for t in ALL_TOOLS if t["function"]["name"] in payload.enabled_tools ] else: tools_to_use = ALL_TOOLS def gguf_generate_with_tools(): return llama_backend.generate_chat_completion_with_tools( messages = gguf_messages, tools = tools_to_use, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_tokens = payload.max_tokens, repetition_penalty = payload.repetition_penalty, presence_penalty = payload.presence_penalty, cancel_event = cancel_event, enable_thinking = payload.enable_thinking, auto_heal_tool_calls = payload.auto_heal_tool_calls if payload.auto_heal_tool_calls is not None else True, max_tool_iterations = payload.max_tool_calls_per_message if payload.max_tool_calls_per_message is not None else 10, tool_call_timeout = payload.tool_call_timeout if payload.tool_call_timeout is not None else 300, session_id = payload.session_id, ) _tool_sentinel = object() async def gguf_tool_stream(): try: first_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(role = "assistant"), finish_reason = None, ) ], ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" # Iterate the synchronous generator in a thread so # the event loop stays free for disconnect detection. 
gen = gguf_generate_with_tools() prev_text = "" while True: if await request.is_disconnected(): cancel_event.set() return event = await asyncio.to_thread(next, gen, _tool_sentinel) if event is _tool_sentinel: break if event["type"] == "status": # Emit tool status as a custom SSE event status_data = json.dumps( { "type": "tool_status", "content": event["text"], } ) yield f"data: {status_data}\n\n" continue if event["type"] in ("tool_start", "tool_end"): yield f"data: {json.dumps(event)}\n\n" continue # "content" type -- cumulative text cumulative = event.get("text", "") new_text = cumulative[len(prev_text) :] prev_text = cumulative if not new_text: continue chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(content = new_text), finish_reason = None, ) ], ) yield f"data: {chunk.model_dump_json(exclude_none = True)}\n\n" final_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(), finish_reason = "stop", ) ], ) yield f"data: {final_chunk.model_dump_json(exclude_none = True)}\n\n" yield "data: [DONE]\n\n" except asyncio.CancelledError: cancel_event.set() raise except Exception as e: import traceback tb = traceback.format_exc() logger.error(f"Error during GGUF tool streaming: {e}\n{tb}") error_chunk = { "error": { "message": "An internal error occurred", "type": "server_error", }, } yield f"data: {json.dumps(error_chunk)}\n\n" return StreamingResponse( gguf_tool_stream(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # ── Standard GGUF path (no tools) ───────────────────── def gguf_generate(): return llama_backend.generate_chat_completion( messages = gguf_messages, image_b64 = image_b64, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_tokens = payload.max_tokens, 
repetition_penalty = payload.repetition_penalty, presence_penalty = payload.presence_penalty, cancel_event = cancel_event, enable_thinking = payload.enable_thinking, ) _gguf_sentinel = object() if payload.stream: async def gguf_stream_chunks(): try: # First chunk: role first_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(role = "assistant"), finish_reason = None, ) ], ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" # Iterate the synchronous generator in a thread so # the event loop stays free for disconnect detection. gen = gguf_generate() prev_text = "" while True: if await request.is_disconnected(): cancel_event.set() return cumulative = await asyncio.to_thread(next, gen, _gguf_sentinel) if cumulative is _gguf_sentinel: break new_text = cumulative[len(prev_text) :] prev_text = cumulative if not new_text: continue chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(content = new_text), finish_reason = None, ) ], ) yield f"data: {chunk.model_dump_json(exclude_none = True)}\n\n" # Final chunk final_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(), finish_reason = "stop", ) ], ) yield f"data: {final_chunk.model_dump_json(exclude_none = True)}\n\n" yield "data: [DONE]\n\n" except asyncio.CancelledError: cancel_event.set() raise except Exception as e: logger.error(f"Error during GGUF streaming: {e}", exc_info = True) error_chunk = { "error": { "message": "An internal error occurred", "type": "server_error", }, } yield f"data: {json.dumps(error_chunk)}\n\n" return StreamingResponse( gguf_stream_chunks(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) else: try: full_text = "" for token in gguf_generate(): 
full_text = token response = ChatCompletion( id = completion_id, created = created, model = model_name, choices = [ CompletionChoice( message = CompletionMessage(content = full_text), finish_reason = "stop", ) ], ) return JSONResponse(content = response.model_dump()) except Exception as e: logger.error(f"Error during GGUF completion: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = str(e)) # ── Standard Unsloth path ───────────────────────────────── # Decode image (from content parts OR legacy field) image_b64 = extracted_image_b64 or payload.image_base64 image = None if image_b64: try: import base64 from PIL import Image from io import BytesIO model_info = backend.models.get(backend.active_model_name, {}) if not model_info.get("is_vision"): raise HTTPException( status_code = 400, detail = "Image provided but current model is text-only. Load a vision model.", ) image_data = base64.b64decode(image_b64) image = Image.open(BytesIO(image_data)) image = backend.resize_image(image) except HTTPException: raise except Exception as e: raise HTTPException(status_code = 400, detail = f"Failed to decode image: {e}") # Shared generation kwargs gen_kwargs = dict( messages = chat_messages, system_prompt = system_prompt, image = image, temperature = payload.temperature, top_p = payload.top_p, top_k = payload.top_k, min_p = payload.min_p, max_new_tokens = payload.max_tokens or 2048, repetition_penalty = payload.repetition_penalty, ) # Choose generation path (adapter-controlled or standard) cancel_event = threading.Event() if payload.use_adapter is not None: def generate(): return backend.generate_with_adapter_control( use_adapter = payload.use_adapter, cancel_event = cancel_event, **gen_kwargs, ) else: def generate(): return backend.generate_chat_response( cancel_event = cancel_event, **gen_kwargs ) completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" created = int(time.time()) # ── Streaming response ──────────────────────────────────────── if payload.stream: 
async def stream_chunks(): try: first_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(role = "assistant"), finish_reason = None, ) ], ) yield f"data: {first_chunk.model_dump_json(exclude_none = True)}\n\n" prev_text = "" # Run sync generator in thread pool to avoid blocking # the event loop. Critical for compare mode: two SSE # requests arrive concurrently but the orchestrator # serializes them via _gen_lock. Without run_in_executor # the second request's blocking lock acquisition would # freeze the entire event loop, stalling both streams. _DONE = object() # sentinel for generator exhaustion loop = asyncio.get_event_loop() gen = generate() while True: # next(gen, _DONE) returns _DONE instead of raising # StopIteration — StopIteration cannot propagate # through asyncio futures (Python limitation). cumulative = await loop.run_in_executor(None, next, gen, _DONE) if cumulative is _DONE: break if await request.is_disconnected(): cancel_event.set() backend.reset_generation_state() return new_text = cumulative[len(prev_text) :] prev_text = cumulative if not new_text: continue chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(content = new_text), finish_reason = None, ) ], ) yield f"data: {chunk.model_dump_json(exclude_none = True)}\n\n" final_chunk = ChatCompletionChunk( id = completion_id, created = created, model = model_name, choices = [ ChunkChoice( delta = ChoiceDelta(), finish_reason = "stop", ) ], ) yield f"data: {final_chunk.model_dump_json(exclude_none = True)}\n\n" yield "data: [DONE]\n\n" except asyncio.CancelledError: cancel_event.set() backend.reset_generation_state() raise except Exception as e: backend.reset_generation_state() logger.error(f"Error during OpenAI streaming: {e}", exc_info = True) error_chunk = { "error": { "message": "An internal error occurred", "type": "server_error", }, } 
yield f"data: {json.dumps(error_chunk)}\n\n" return StreamingResponse( stream_chunks(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) # ── Non-streaming response ──────────────────────────────────── else: try: full_text = "" for token in generate(): full_text = token response = ChatCompletion( id = completion_id, created = created, model = model_name, choices = [ CompletionChoice( message = CompletionMessage(content = full_text), finish_reason = "stop", ) ], ) return JSONResponse(content = response.model_dump()) except Exception as e: backend.reset_generation_state() logger.error(f"Error during OpenAI completion: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = str(e)) # ===================================================================== # OpenAI-Compatible Models Listing (/models → /v1/models) # ===================================================================== @router.get("/models") async def openai_list_models( current_subject: str = Depends(get_current_subject), ): """ OpenAI-compatible model listing endpoint. Returns the currently loaded model in the format expected by OpenAI-compatible clients (``GET /v1/models``). """ models = [] # Check GGUF backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded: models.append( { "id": llama_backend.model_identifier, "object": "model", "owned_by": "local", } ) # Check Unsloth backend backend = get_inference_backend() if backend.active_model_name: models.append( { "id": backend.active_model_name, "object": "model", "owned_by": "local", } ) return {"object": "list", "data": models} ================================================ FILE: studio/backend/routes/models.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Model Management API routes """ import os import sys from pathlib import Path from fastapi import APIRouter, Body, Depends, HTTPException, Query from typing import List, Optional import structlog from loggers import get_logger import re as _re _VALID_REPO_ID = _re.compile(r"^[A-Za-z0-9._-]+/[A-Za-z0-9._-]+$") def _is_valid_repo_id(repo_id: str) -> bool: return bool(_VALID_REPO_ID.fullmatch(repo_id)) # Add backend directory to path backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) from auth.authentication import get_current_subject # Import backend functions try: from utils.models import ( scan_trained_loras, scan_exported_models, load_model_defaults, get_base_model_from_lora, is_vision_model, is_embedding_model, scan_checkpoints, list_gguf_variants, ModelConfig, ) from utils.models.model_config import ( _pick_best_gguf, _extract_quant_label, is_audio_input_type, ) from core.inference import get_inference_backend from utils.paths import ( outputs_root, exports_root, resolve_output_dir, resolve_export_dir, ) except ImportError: # Fallback: try to import from parent directory parent_backend = backend_path.parent / "backend" if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from utils.models import ( scan_trained_loras, scan_exported_models, load_model_defaults, get_base_model_from_lora, is_vision_model, is_embedding_model, scan_checkpoints, list_gguf_variants, ModelConfig, ) from utils.models.model_config import ( _pick_best_gguf, _extract_quant_label, is_audio_input_type, ) from core.inference import get_inference_backend from utils.paths import ( outputs_root, exports_root, resolve_output_dir, resolve_export_dir, ) from models import ( CheckpointInfo, CheckpointListResponse, LocalModelInfo, LocalModelListResponse, ModelCheckpoints, ModelDetails, LoRAScanResponse, LoRAInfo, ModelListResponse, ) from models.models import 
GgufVariantDetail, GgufVariantsResponse, ModelType from models.responses import ( LoRABaseModelResponse, VisionCheckResponse, EmbeddingCheckResponse, ) router = APIRouter() logger = get_logger(__name__) def derive_model_type( is_vision: bool, audio_type: Optional[str], is_embedding: bool = False ) -> ModelType: """Collapse individual capability flags into a single model modality string.""" if is_embedding: return "embeddings" if audio_type is not None: return "audio" if is_vision: return "vision" return "text" def _resolve_hf_cache_dir() -> Path: """Resolve local HF cache root used by hub downloads.""" try: from huggingface_hub.constants import HF_HUB_CACHE return Path(HF_HUB_CACHE) except Exception: return Path.home() / ".cache" / "huggingface" / "hub" def _scan_models_dir(models_dir: Path) -> List[LocalModelInfo]: if not models_dir.exists() or not models_dir.is_dir(): return [] found: List[LocalModelInfo] = [] for child in models_dir.iterdir(): if not child.is_dir(): continue has_model_files = ( (child / "config.json").exists() or (child / "adapter_config.json").exists() or any(child.glob("*.safetensors")) or any(child.glob("*.bin")) or any(child.glob("*.gguf")) ) if not has_model_files: continue try: updated_at = child.stat().st_mtime except OSError: updated_at = None found.append( LocalModelInfo( id = str(child), display_name = child.name, path = str(child), source = "models_dir", updated_at = updated_at, ), ) # Also scan for standalone .gguf files directly in the models directory for gguf_file in models_dir.glob("*.gguf"): if gguf_file.is_file(): try: updated_at = gguf_file.stat().st_mtime except OSError: updated_at = None found.append( LocalModelInfo( id = str(gguf_file), display_name = gguf_file.stem, path = str(gguf_file), source = "models_dir", updated_at = updated_at, ), ) return found def _scan_hf_cache(cache_dir: Path) -> List[LocalModelInfo]: if not cache_dir.exists() or not cache_dir.is_dir(): return [] found: List[LocalModelInfo] = [] for repo_dir in 
cache_dir.glob("models--*"): if not repo_dir.is_dir(): continue repo_name = repo_dir.name[len("models--") :] if not repo_name: continue model_id = repo_name.replace("--", "/") try: updated_at = repo_dir.stat().st_mtime except OSError: updated_at = None found.append( LocalModelInfo( id = model_id, model_id = model_id, display_name = model_id.split("/")[-1], path = str(repo_dir), source = "hf_cache", updated_at = updated_at, ), ) return found @router.get("/local", response_model = LocalModelListResponse) async def list_local_models( models_dir: str = Query( default = "./models", description = "Directory to scan for local model folders" ), current_subject: str = Depends(get_current_subject), ): """ List local model candidates from custom models dir and HF cache. """ # Validate models_dir against an allowlist of trusted directories. # Only the trusted Path objects are used for filesystem access -- the # user-supplied string is only used for matching, never for path construction. hf_cache_dir = _resolve_hf_cache_dir() allowed_roots = [Path("./models").resolve(), hf_cache_dir] try: from utils.paths import studio_root, outputs_root allowed_roots.extend([studio_root(), outputs_root()]) except Exception: pass requested = os.path.realpath(os.path.expanduser(models_dir)) models_root = None for root in allowed_roots: root_str = os.path.realpath(str(root)) if requested == root_str or requested.startswith(root_str + os.sep): models_root = root # Use the trusted root, not the user-supplied path break if models_root is None: raise HTTPException( status_code = 403, detail = "Directory not allowed", ) try: local_models = _scan_models_dir(models_root) + _scan_hf_cache(hf_cache_dir) deduped: dict[str, LocalModelInfo] = {} for model in local_models: if model.id not in deduped: deduped[model.id] = model models = sorted( deduped.values(), key = lambda item: (item.updated_at or 0), reverse = True, ) return LocalModelListResponse( models_dir = str(models_root), hf_cache_dir = 
str(hf_cache_dir), models = models, ) except Exception as e: logger.error(f"Error listing local models: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to list local models: {str(e)}", ) @router.get("/list") async def list_models( current_subject: str = Depends(get_current_subject), ): """ List available models (default models and loaded models). This endpoint returns the default models and any currently loaded models. """ try: inference_backend = get_inference_backend() # Get default models default_models = inference_backend.default_models # Get loaded models loaded_models = [] for model_name, model_data in inference_backend.models.items(): _is_vision = model_data.get("is_vision", False) _audio_type = model_data.get("audio_type") model_info = ModelDetails( id = model_name, name = model_name.split("/")[-1] if "/" in model_name else model_name, is_vision = _is_vision, is_lora = model_data.get("is_lora", False), is_audio = model_data.get("is_audio", False), audio_type = _audio_type, has_audio_input = model_data.get("has_audio_input", False), model_type = derive_model_type(_is_vision, _audio_type), ) loaded_models.append(model_info) # Include active GGUF model (loaded via llama-server) from routes.inference import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: loaded_models.append( ModelDetails( id = llama_backend.model_identifier, name = llama_backend.model_identifier.split("/")[-1], is_gguf = True, is_vision = llama_backend.is_vision, is_audio = getattr(llama_backend, "_is_audio", False), audio_type = getattr(llama_backend, "_audio_type", None), ) ) # Combine default and loaded models all_models = [] seen_ids = set() # Add default models for model_id in default_models: if model_id not in seen_ids: model_info = ModelDetails( id = model_id, name = model_id.split("/")[-1] if "/" in model_id else model_id, is_gguf = model_id.upper().endswith("-GGUF"), ) 
all_models.append(model_info) seen_ids.add(model_id) # Add loaded models for model_info in loaded_models: if model_info.id not in seen_ids: all_models.append(model_info) seen_ids.add(model_info.id) return ModelListResponse(models = all_models, default_models = default_models) except Exception as e: logger.error(f"Error listing models: {e}", exc_info = True) raise HTTPException(status_code = 500, detail = f"Failed to list models: {str(e)}") def _get_max_position_embeddings(config) -> Optional[int]: """Extract max_position_embeddings from a model config, checking text_config fallback.""" if hasattr(config, "max_position_embeddings"): return config.max_position_embeddings if hasattr(config, "text_config") and hasattr( config.text_config, "max_position_embeddings" ): return config.text_config.max_position_embeddings return None def _get_model_size_bytes( model_name: str, hf_token: Optional[str] = None ) -> Optional[int]: """Get total size of model weight files from HF Hub.""" try: from huggingface_hub import HfApi api = HfApi(token = hf_token) info = api.repo_info(model_name, repo_type = "model", token = hf_token) if not info.siblings: return None weight_exts = (".safetensors", ".bin", ".pt", ".pth", ".gguf") total = 0 for sibling in info.siblings: if sibling.rfilename and any( sibling.rfilename.endswith(ext) for ext in weight_exts ): if sibling.size is not None: total += sibling.size return total if total > 0 else None except Exception as e: logger.warning(f"Could not get model size for {model_name}: {e}") return None @router.get("/config/{model_name:path}") async def get_model_config( model_name: str, hf_token: Optional[str] = Query(None), current_subject: str = Depends(get_current_subject), ): """ Get configuration for a specific model. This endpoint wraps the backend load_model_defaults function. 
""" try: from utils.models.model_config import is_local_path if not is_local_path(model_name): model_name = model_name.lower() logger.info(f"Getting model config for: {model_name}") from utils.models.model_config import detect_audio_type # Load model defaults from backend config_dict = load_model_defaults(model_name) # Detect model capabilities (pass HF token for gated models) is_vision = is_vision_model(model_name) is_embedding = is_embedding_model(model_name, hf_token = hf_token) audio_type = detect_audio_type(model_name, hf_token = hf_token) # Check if it's a LoRA adapter is_lora = False base_model = None max_position_embeddings = None try: model_config = ModelConfig.from_identifier(model_name) is_lora = model_config.is_lora base_model = model_config.base_model if is_lora else None max_position_embeddings = _get_max_position_embeddings(model_config) except Exception: pass # Fallback: try AutoConfig directly if not found yet if max_position_embeddings is None: try: from transformers import AutoConfig as _AutoConfig _trust = model_name.lower().startswith("unsloth/") _ac = _AutoConfig.from_pretrained( model_name, trust_remote_code = _trust, token = hf_token ) max_position_embeddings = _get_max_position_embeddings(_ac) except Exception: pass logger.info( f"Model config result for {model_name}: is_vision={is_vision}, is_embedding={is_embedding}, audio_type={audio_type}, is_lora={is_lora}, max_position_embeddings={max_position_embeddings}" ) return ModelDetails( id = model_name, model_name = model_name, config = config_dict, is_vision = is_vision, is_embedding = is_embedding, is_lora = is_lora, is_audio = audio_type is not None, audio_type = audio_type, has_audio_input = is_audio_input_type(audio_type), model_type = derive_model_type(is_vision, audio_type, is_embedding), base_model = base_model, max_position_embeddings = max_position_embeddings, model_size_bytes = _get_model_size_bytes(model_name, hf_token), ) except Exception as e: logger.error(f"Error getting model 
config: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to get model config: {str(e)}" ) @router.get("/loras") async def scan_loras( outputs_dir: str = Query( default = str(outputs_root()), description = "Directory to scan for LoRA adapters" ), exports_dir: str = Query( default = str(exports_root()), description = "Directory to scan for exported models" ), current_subject: str = Depends(get_current_subject), ): """ Scan for trained LoRA adapters and exported models. Returns both training outputs (from outputs_dir) and exported models (from exports_dir) in a single list, distinguished by source field. """ try: resolved_outputs_dir = str(resolve_output_dir(outputs_dir)) resolved_exports_dir = str(resolve_export_dir(exports_dir)) lora_list = [] # Scan training outputs trained_loras = scan_trained_loras(outputs_dir = resolved_outputs_dir) for display_name, adapter_path in trained_loras: base_model = get_base_model_from_lora(adapter_path) lora_list.append( LoRAInfo( display_name = display_name, adapter_path = adapter_path, base_model = base_model, source = "training", ) ) # Scan exported models (merged, LoRA, base — skips GGUF) exported = scan_exported_models(exports_dir = resolved_exports_dir) for display_name, model_path, export_type, base_model in exported: lora_list.append( LoRAInfo( display_name = display_name, adapter_path = model_path, base_model = base_model, source = "exported", export_type = export_type, ) ) return LoRAScanResponse(loras = lora_list, outputs_dir = resolved_outputs_dir) except Exception as e: logger.error(f"Error scanning LoRAs: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to scan LoRA adapters: {str(e)}" ) @router.get("/loras/{lora_path:path}/base-model", response_model = LoRABaseModelResponse) async def get_lora_base_model( lora_path: str, current_subject: str = Depends(get_current_subject), ): """ Get the base model for a LoRA adapter. 
This endpoint wraps the backend get_base_model_from_lora function. """ try: base_model = get_base_model_from_lora(lora_path) if base_model is None: raise HTTPException( status_code = 404, detail = f"Could not determine base model for LoRA: {lora_path}", ) return LoRABaseModelResponse( lora_path = lora_path, base_model = base_model, ) except HTTPException: raise except Exception as e: logger.error(f"Error getting LoRA base model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to get base model: {str(e)}" ) @router.get("/check-vision/{model_name:path}", response_model = VisionCheckResponse) async def check_vision_model( model_name: str, current_subject: str = Depends(get_current_subject), ): """ Check if a model is a vision model. This endpoint wraps the backend is_vision_model function. """ try: logger.info(f"Checking if vision model: {model_name}") is_vision = is_vision_model(model_name) logger.info(f"Vision check result for {model_name}: is_vision={is_vision}") return VisionCheckResponse( model_name = model_name, is_vision = is_vision, ) except Exception as e: logger.error(f"Error checking vision model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to check vision model: {str(e)}" ) @router.get("/check-embedding/{model_name:path}", response_model = EmbeddingCheckResponse) async def check_embedding_model( model_name: str, hf_token: Optional[str] = Query(None), current_subject: str = Depends(get_current_subject), ): """ Check if a model is an embedding model. This endpoint wraps the backend is_embedding_model function. 
""" try: logger.info(f"Checking if embedding model: {model_name}") is_embedding = is_embedding_model(model_name, hf_token = hf_token) logger.info( f"Embedding check result for {model_name}: is_embedding={is_embedding}" ) return EmbeddingCheckResponse( model_name = model_name, is_embedding = is_embedding, ) except Exception as e: logger.error(f"Error checking embedding model: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to check embedding model: {str(e)}" ) @router.get("/gguf-variants", response_model = GgufVariantsResponse) async def get_gguf_variants( repo_id: str = Query( ..., description = "HuggingFace repo ID (e.g. 'unsloth/gemma-3-4b-it-GGUF')" ), hf_token: Optional[str] = Query( None, description = "HuggingFace token for private repos" ), current_subject: str = Depends(get_current_subject), ): """ List available GGUF quantization variants for a HuggingFace repo. Returns all available quantization variants (Q4_K_M, Q8_0, BF16, etc.) with file sizes, whether the model supports vision, and the recommended default variant. """ try: variants, has_vision = list_gguf_variants(repo_id, hf_token = hf_token) # Determine default variant filenames = [v.filename for v in variants] best = _pick_best_gguf(filenames) default_variant = _extract_quant_label(best) if best else None # Check which variants are fully downloaded in the HF cache. # For split GGUFs, ALL shards must be present -- sum cached bytes # per variant and compare against the expected total. # HF cache dir uses the exact case from the repo_id at download time, # which may differ from the canonical HF repo_id, so do a # case-insensitive match. 
cached_bytes_by_quant: dict[str, int] = {} try: import re as _re from huggingface_hub import constants as hf_constants # Sanitize repo_id: must be "owner/name" with safe chars only if not _is_valid_repo_id(repo_id): raise ValueError(f"Invalid repo_id format: {repo_id}") cache_dir = Path(hf_constants.HF_HUB_CACHE) target = f"models--{repo_id.replace('/', '--')}".lower() for entry in cache_dir.iterdir(): if entry.name.lower() == target: snapshots = entry / "snapshots" if snapshots.is_dir(): for snap in snapshots.iterdir(): for f in snap.rglob("*.gguf"): q = _extract_quant_label(f.name) cached_bytes_by_quant[q] = ( cached_bytes_by_quant.get(q, 0) + f.stat().st_size ) break except Exception: pass def _is_fully_downloaded(variant) -> bool: cached = cached_bytes_by_quant.get(variant.quant, 0) if cached == 0 or variant.size_bytes == 0: return False # Allow small rounding tolerance (symlinks vs real sizes) return cached >= variant.size_bytes * 0.99 return GgufVariantsResponse( repo_id = repo_id, variants = [ GgufVariantDetail( filename = v.filename, quant = v.quant, size_bytes = v.size_bytes, downloaded = _is_fully_downloaded(v), ) for v in variants ], has_vision = has_vision, default_variant = default_variant, ) except Exception as e: logger.error(f"Error listing GGUF variants for '{repo_id}': {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to list GGUF variants: {str(e)}", ) @router.get("/gguf-download-progress") async def get_gguf_download_progress( repo_id: str = Query(..., description = "HuggingFace repo ID"), variant: str = Query("", description = "Quantization variant (e.g. UD-TQ1_0)"), expected_bytes: int = Query(0, description = "Expected total download size in bytes"), current_subject: str = Depends(get_current_subject), ): """Return download progress by checking cached GGUF files for a specific variant. Tracks completed shard downloads in snapshots and in-progress downloads in the blobs directory (incomplete files). 
""" try: if not _is_valid_repo_id(repo_id): return { "downloaded_bytes": 0, "expected_bytes": expected_bytes, "progress": 0, } from huggingface_hub import constants as hf_constants cache_dir = Path(hf_constants.HF_HUB_CACHE) target = f"models--{repo_id.replace('/', '--')}".lower() variant_lower = variant.lower().replace("-", "").replace("_", "") downloaded_bytes = 0 in_progress_bytes = 0 for entry in cache_dir.iterdir(): if entry.name.lower() == target: # Count completed .gguf files matching this variant in snapshots for f in entry.rglob("*.gguf"): fname = f.name.lower().replace("-", "").replace("_", "") if not variant_lower or variant_lower in fname: downloaded_bytes += f.stat().st_size # Check blobs for in-progress downloads (.incomplete files) blobs_dir = entry / "blobs" if blobs_dir.is_dir(): for f in blobs_dir.iterdir(): if f.is_file() and f.name.endswith(".incomplete"): in_progress_bytes += f.stat().st_size break total_progress_bytes = downloaded_bytes + in_progress_bytes progress = ( min(total_progress_bytes / expected_bytes, 0.99) if expected_bytes > 0 else 0 ) # Only report 1.0 when all bytes are in completed files (not in-progress) if expected_bytes > 0 and downloaded_bytes >= expected_bytes: progress = 1.0 return { "downloaded_bytes": total_progress_bytes, "expected_bytes": expected_bytes, "progress": round(progress, 3), } except Exception: return {"downloaded_bytes": 0, "expected_bytes": expected_bytes, "progress": 0} @router.get("/download-progress") async def get_download_progress( repo_id: str = Query(..., description = "HuggingFace repo ID"), current_subject: str = Depends(get_current_subject), ): """Return download progress for any HuggingFace model repo. Checks the local HF cache for completed blobs and in-progress (.incomplete) downloads. Uses the HF API to determine the expected total size on the first call, then caches it for subsequent polls. 
""" _empty = {"downloaded_bytes": 0, "expected_bytes": 0, "progress": 0} try: if not _is_valid_repo_id(repo_id): return _empty from huggingface_hub import constants as hf_constants cache_dir = Path(hf_constants.HF_HUB_CACHE) target = f"models--{repo_id.replace('/', '--')}".lower() completed_bytes = 0 in_progress_bytes = 0 for entry in cache_dir.iterdir(): if entry.name.lower() != target: continue blobs_dir = entry / "blobs" if not blobs_dir.is_dir(): break for f in blobs_dir.iterdir(): if not f.is_file(): continue if f.name.endswith(".incomplete"): in_progress_bytes += f.stat().st_size else: completed_bytes += f.stat().st_size break downloaded_bytes = completed_bytes + in_progress_bytes if downloaded_bytes == 0: return _empty # Get expected size from HF API (cached per repo_id) expected_bytes = _get_repo_size_cached(repo_id) if expected_bytes <= 0: # Cannot determine total; report bytes only, no percentage return { "downloaded_bytes": downloaded_bytes, "expected_bytes": 0, "progress": 0, } # Use 95% threshold for completion (blob deduplication can make # completed_bytes differ slightly from expected_bytes). # Do NOT use "no .incomplete files" as a completion signal -- # HF downloads files sequentially, so between files there are # no .incomplete files even though the download is far from done. 
if completed_bytes >= expected_bytes * 0.95: progress = 1.0 else: progress = min(downloaded_bytes / expected_bytes, 0.99) return { "downloaded_bytes": downloaded_bytes, "expected_bytes": expected_bytes, "progress": round(progress, 3), } except Exception as e: logger.warning(f"Error checking download progress for {repo_id}: {e}") return _empty _repo_size_cache: dict[str, int] = {} def _get_repo_size_cached(repo_id: str) -> int: if repo_id in _repo_size_cache: return _repo_size_cache[repo_id] try: from huggingface_hub import model_info as hf_model_info info = hf_model_info(repo_id, token = None, files_metadata = True) total = sum(s.size for s in info.siblings if s.size) _repo_size_cache[repo_id] = total return total except Exception as e: logger.warning(f"Failed to get repo size for {repo_id}: {e}") return 0 @router.get("/cached-gguf") async def list_cached_gguf( current_subject: str = Depends(get_current_subject), ): """List GGUF repos that have already been downloaded to the HF cache. Uses scan_cache_dir() for proper repo IDs, then deduplicates by lowercased key (HF cache dirs are lowercased but the canonical repo ID preserves casing). 
""" try: from huggingface_hub import scan_cache_dir hf_cache = scan_cache_dir() seen_lower: dict[str, dict] = {} for repo_info in hf_cache.repos: if repo_info.repo_type != "model": continue repo_id = repo_info.repo_id if not repo_id.upper().endswith("-GGUF"): continue # Check for actual .gguf files and sum sizes total_size = 0 has_gguf = False for revision in repo_info.revisions: for f in revision.files: if f.file_name.endswith(".gguf"): has_gguf = True total_size += f.size_on_disk if not has_gguf: continue # Deduplicate: keep the entry with the most data key = repo_id.lower() existing = seen_lower.get(key) if existing is None or total_size > existing["size_bytes"]: seen_lower[key] = { "repo_id": repo_id, "size_bytes": total_size, "cache_path": str(repo_info.repo_path), } cached = sorted(seen_lower.values(), key = lambda c: c["repo_id"]) return {"cached": cached} except Exception as e: logger.error(f"Error listing cached GGUF repos: {e}", exc_info = True) return {"cached": []} @router.get("/cached-models") async def list_cached_models( current_subject: str = Depends(get_current_subject), ): """List non-GGUF model repos that have been downloaded to the HF cache. Only includes repos that actually contain model weight files (.safetensors, .bin), not repos with only config/metadata. 
""" _WEIGHT_EXTENSIONS = (".safetensors", ".bin") try: from huggingface_hub import scan_cache_dir hf_cache = scan_cache_dir() seen_lower: dict[str, dict] = {} for repo_info in hf_cache.repos: if repo_info.repo_type != "model": continue repo_id = repo_info.repo_id if repo_id.upper().endswith("-GGUF"): continue total_size = sum( f.size_on_disk for rev in repo_info.revisions for f in rev.files ) if total_size == 0: continue # Skip repos that only have config/metadata files (no weights) has_weights = any( f.file_name.endswith(_WEIGHT_EXTENSIONS) for rev in repo_info.revisions for f in rev.files ) if not has_weights: continue key = repo_id.lower() existing = seen_lower.get(key) if existing is None or total_size > existing["size_bytes"]: seen_lower[key] = { "repo_id": repo_id, "size_bytes": total_size, } cached = sorted(seen_lower.values(), key = lambda c: c["repo_id"]) return {"cached": cached} except Exception as e: logger.error(f"Error listing cached models: {e}", exc_info = True) return {"cached": []} @router.delete("/delete-cached") async def delete_cached_model( repo_id: str = Body(...), variant: Optional[str] = Body(None), current_subject: str = Depends(get_current_subject), ): """Delete a cached model repo (or a specific GGUF variant) from the HF cache. When *variant* is provided, only the GGUF files matching that quant label are removed (e.g. ``UD-Q4_K_XL``). Otherwise the entire repo is deleted. Refuses if the model is currently loaded for inference. 
""" if not _is_valid_repo_id(repo_id): raise HTTPException(status_code = 400, detail = "Invalid repo_id format") # Check if model is currently loaded try: from routes.inference import get_llama_cpp_backend llama_backend = get_llama_cpp_backend() if llama_backend.is_loaded and llama_backend.model_identifier: loaded_id = llama_backend.model_identifier.lower() if loaded_id == repo_id.lower() or loaded_id.startswith(repo_id.lower()): raise HTTPException( status_code = 400, detail = "Unload the model before deleting", ) except HTTPException: raise except Exception: pass try: inference_backend = get_inference_backend() if inference_backend.active_model_name: active = inference_backend.active_model_name.lower() if active == repo_id.lower() or active.startswith(repo_id.lower()): raise HTTPException( status_code = 400, detail = "Unload the model before deleting", ) except HTTPException: raise except Exception: pass try: from huggingface_hub import scan_cache_dir hf_cache = scan_cache_dir() target_repo = None for repo_info in hf_cache.repos: if repo_info.repo_type != "model": continue if repo_info.repo_id.lower() == repo_id.lower(): target_repo = repo_info break if target_repo is None: raise HTTPException(status_code = 404, detail = "Model not found in cache") # ── Per-variant GGUF deletion ──────────────────────────── if variant: deleted_bytes = 0 deleted_count = 0 for rev in target_repo.revisions: for f in rev.files: if not f.file_name.endswith(".gguf"): continue quant = _extract_quant_label(f.file_name) if quant.lower() != variant.lower(): continue # Delete the blob (actual data) and the snapshot symlink try: blob = Path(f.blob_path) snap = Path(f.file_path) size = blob.stat().st_size if blob.exists() else 0 if snap.exists() or snap.is_symlink(): snap.unlink() if blob.exists(): blob.unlink() deleted_bytes += size deleted_count += 1 except Exception as e: logger.warning(f"Failed to delete {f.file_name}: {e}") if deleted_count == 0: raise HTTPException( status_code = 404, 
detail = f"Variant {variant} not found in cache for {repo_id}", ) freed_mb = deleted_bytes / (1024 * 1024) logger.info( f"Deleted {deleted_count} file(s) for {repo_id} variant {variant}: " f"{freed_mb:.1f} MB freed" ) return {"status": "deleted", "repo_id": repo_id, "variant": variant} # ── Full repo deletion ─────────────────────────────────── revision_hashes = [rev.commit_hash for rev in target_repo.revisions] if not revision_hashes: raise HTTPException(status_code = 404, detail = "No revisions found for model") delete_strategy = hf_cache.delete_revisions(*revision_hashes) logger.info( f"Deleting cached model {repo_id}: " f"{delete_strategy.expected_freed_size_str} will be freed" ) delete_strategy.execute() return {"status": "deleted", "repo_id": repo_id} except HTTPException: raise except Exception as e: logger.error(f"Error deleting cached model {repo_id}: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to delete cached model: {str(e)}", ) @router.get("/checkpoints", response_model = CheckpointListResponse) async def list_checkpoints( outputs_dir: str = Query( default = str(outputs_root()), description = "Directory to scan for checkpoints", ), current_subject: str = Depends(get_current_subject), ): """ List available checkpoints in the outputs directory. Scans the outputs folder for training runs and their checkpoints. 
""" try: resolved_outputs_dir = str(resolve_output_dir(outputs_dir)) raw_models = scan_checkpoints(outputs_dir = resolved_outputs_dir) models = [ ModelCheckpoints( name = model_name, checkpoints = [ CheckpointInfo(display_name = display_name, path = path, loss = loss) for display_name, path, loss in checkpoints ], base_model = metadata.get("base_model"), peft_type = metadata.get("peft_type"), lora_rank = metadata.get("lora_rank"), ) for model_name, checkpoints, metadata in raw_models ] return CheckpointListResponse( outputs_dir = resolved_outputs_dir, models = models, ) except Exception as e: logger.error(f"Error listing checkpoints: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to list checkpoints: {str(e)}", ) ================================================ FILE: studio/backend/routes/training.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Training API routes """ import sys from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import StreamingResponse from typing import Dict, Optional, Any import structlog from loggers import get_logger import asyncio from datetime import datetime # Add backend directory to path # The backend code should be in the same directory structure backend_path = Path(__file__).parent.parent.parent if str(backend_path) not in sys.path: sys.path.insert(0, str(backend_path)) # Import backend functions try: from core.training import get_training_backend from utils.models.model_config import load_model_defaults from utils.paths import resolve_dataset_path except ImportError: # Fallback: try to import from parent directory parent_backend = backend_path.parent / "backend" if str(parent_backend) not in sys.path: sys.path.insert(0, str(parent_backend)) from core.training import get_training_backend from utils.models.model_config import load_model_defaults from utils.paths import resolve_dataset_path # Auth from auth.authentication import get_current_subject from models import ( TrainingStartRequest, TrainingJobResponse, TrainingStatus, TrainingProgress, ) from models.responses import TrainingStopResponse, TrainingMetricsResponse from pydantic import BaseModel as PydanticBaseModel class TrainingStopRequest(PydanticBaseModel): save: bool = True router = APIRouter() logger = get_logger(__name__) def _validate_local_dataset_paths( paths: list[str], label: str = "Local dataset" ) -> list[str]: """Resolve and validate a list of local dataset paths. 
Returns validated absolute paths.""" validated = [] missing = [] for dataset_path in paths: dataset_file = resolve_dataset_path(dataset_path) if not dataset_file.exists(): missing.append(f"{dataset_path} (resolved: {dataset_file})") continue logger.info(f"Found {label.lower()} file: {dataset_file}") validated.append(str(dataset_file)) if missing: missing_detail = "; ".join(missing[:3]) raise HTTPException( status_code = 400, detail = f"{label} not found: {missing_detail}", ) return validated @router.get("/hardware") async def get_hardware_utilization( current_subject: str = Depends(get_current_subject), ): """ Get a live snapshot of GPU hardware utilization. Designed to be polled by the frontend during training. Returns GPU utilization %, temperature, VRAM usage, and power draw via nvidia-smi for maximum accuracy. """ from utils.hardware import get_gpu_utilization return get_gpu_utilization() @router.post("/start") async def start_training( request: TrainingStartRequest, current_subject: str = Depends(get_current_subject), ): """ Start a training job. This endpoint initiates training in the background and returns immediately. Use the /status endpoint to check training progress. """ try: logger.info(f"Starting training job with model: {request.model_name}") # NOTE: No in-process ensure_transformers_version() call here. # The subprocess (worker.py) activates the correct version in a # fresh Python interpreter before importing any ML libraries. backend = get_training_backend() # Generate job ID and attach to backend for later status/progress calls job_id = f"job_{datetime.now().strftime('%Y%m%d_%H%M%S')}" backend.current_job_id = job_id # Check if training is already active if backend.is_training_active(): existing_job_id: Optional[str] = getattr(backend, "current_job_id", "") return TrainingJobResponse( job_id = existing_job_id or job_id, status = "error", message = ( "Training is already in progress. " "Stop current training before starting a new one." 
), error = "Training already active", ) # Validate dataset paths if provided if request.local_datasets: request.local_datasets = _validate_local_dataset_paths( request.local_datasets, "Local dataset" ) if request.local_eval_datasets and request.eval_steps > 0: request.local_eval_datasets = _validate_local_dataset_paths( request.local_eval_datasets, "Local eval dataset" ) # Convert request to kwargs for backend training_kwargs = { "model_name": request.model_name, "training_type": request.training_type, "hf_token": request.hf_token or "", "load_in_4bit": request.load_in_4bit, "max_seq_length": request.max_seq_length, "hf_dataset": request.hf_dataset or "", "local_datasets": request.local_datasets, "local_eval_datasets": request.local_eval_datasets, "format_type": request.format_type, "subset": request.subset, "train_split": request.train_split, "eval_split": request.eval_split, "eval_steps": request.eval_steps, "dataset_slice_start": request.dataset_slice_start, "dataset_slice_end": request.dataset_slice_end, "custom_format_mapping": request.custom_format_mapping, "num_epochs": request.num_epochs, "learning_rate": request.learning_rate, "batch_size": request.batch_size, "gradient_accumulation_steps": request.gradient_accumulation_steps, "warmup_steps": request.warmup_steps, "warmup_ratio": request.warmup_ratio, "max_steps": request.max_steps, "save_steps": request.save_steps, "weight_decay": request.weight_decay, "random_seed": request.random_seed, "packing": request.packing, "optim": request.optim, "lr_scheduler_type": request.lr_scheduler_type, "use_lora": request.use_lora, "lora_r": request.lora_r, "lora_alpha": request.lora_alpha, "lora_dropout": request.lora_dropout, "target_modules": request.target_modules if request.target_modules else None, "gradient_checkpointing": request.gradient_checkpointing.strip() if request.gradient_checkpointing and request.gradient_checkpointing.strip() else "unsloth", "use_rslora": request.use_rslora, "use_loftq": 
request.use_loftq, "train_on_completions": request.train_on_completions, "finetune_vision_layers": request.finetune_vision_layers, "finetune_language_layers": request.finetune_language_layers, "finetune_attention_modules": request.finetune_attention_modules, "finetune_mlp_modules": request.finetune_mlp_modules, "is_dataset_image": request.is_dataset_image, "is_dataset_audio": request.is_dataset_audio, "is_embedding": request.is_embedding, "enable_wandb": request.enable_wandb, "wandb_token": request.wandb_token or "", "wandb_project": request.wandb_project or "", "enable_tensorboard": request.enable_tensorboard, "tensorboard_dir": request.tensorboard_dir or "", "trust_remote_code": request.trust_remote_code, } # Training page has no trust_remote_code toggle — the value comes from # YAML model defaults applied when the user selects a model. As a safety # net, consult the YAML directly so models that need it always get it. if not training_kwargs["trust_remote_code"]: model_defaults = load_model_defaults(request.model_name) yaml_trust = model_defaults.get("training", {}).get( "trust_remote_code", False ) if yaml_trust: logger.info( f"YAML config sets trust_remote_code=True for {request.model_name}" ) training_kwargs["trust_remote_code"] = True # Free GPU memory: shut down any running inference/export subprocesses # before training starts (they'd compete for VRAM otherwise) try: from core.inference import get_inference_backend inf_backend = get_inference_backend() if inf_backend.active_model_name: logger.info( "Unloading inference model '%s' to free GPU memory for training", inf_backend.active_model_name, ) inf_backend._shutdown_subprocess() inf_backend.active_model_name = None inf_backend.models.clear() except Exception as e: logger.warning("Could not unload inference model: %s", e) try: from core.export import get_export_backend exp_backend = get_export_backend() if exp_backend.current_checkpoint: logger.info( "Shutting down export subprocess to free GPU memory for 
training" ) exp_backend._shutdown_subprocess() exp_backend.current_checkpoint = None exp_backend.is_vision = False exp_backend.is_peft = False except Exception as e: logger.warning("Could not shut down export subprocess: %s", e) # start_training now spawns a subprocess (non-blocking) success = backend.start_training(**training_kwargs) if not success: progress_error = backend.trainer.training_progress.error return TrainingJobResponse( job_id = job_id, status = "error", message = progress_error or "Failed to start training subprocess", error = progress_error or "subprocess_start_failed", ) return TrainingJobResponse( job_id = job_id, status = "queued", message = "Training job queued and starting in subprocess", error = None, ) except Exception as e: logger.error(f"Error starting training: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to start training: {str(e)}", ) @router.post("/stop", response_model = TrainingStopResponse) async def stop_training( body: TrainingStopRequest = TrainingStopRequest(), current_subject: str = Depends(get_current_subject), ): """ Stop the currently running training job. Body: save (bool): If True (default), save the model at the current checkpoint. """ try: backend = get_training_backend() is_active = backend.is_training_active() logger.info("Stop requested: save=%s is_active=%s", body.save, is_active) if not is_active: return TrainingStopResponse( status = "idle", message = "No training job is currently running" ) # Call backend stop method backend.stop_training(save = body.save) return TrainingStopResponse( status = "stopped", message = "Stop requested. 
Training will stop at the next safe step.", ) except Exception as e: logger.error(f"Error stopping training: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to stop training: {str(e)}" ) @router.post("/reset") async def reset_training( current_subject: str = Depends(get_current_subject), ): """ Reset training state so the user can return to configuration. """ try: backend = get_training_backend() is_active = backend.is_training_active() if is_active: if backend._cancel_requested: # Cancel (save=False) was requested — force-terminate so we can reset immediately logger.info( "Force-terminating subprocess for immediate reset (cancel path)" ) backend.force_terminate() else: logger.warning( "Rejected reset while training active: is_active=%s", is_active ) raise HTTPException( status_code = 409, detail = "Training is still running. Stop training and wait for it to finish before resetting.", ) logger.info("Reset training state: clearing runtime + metric history") backend._should_stop = False # Clear stop flag so status returns to idle backend.trainer._update_progress( is_training = False, is_completed = False, error = None, status_message = "Ready to train", step = 0, loss = 0.0, epoch = 0, total_steps = 0, ) backend.loss_history = [] backend.lr_history = [] backend.step_history = [] backend.grad_norm_history = [] backend.grad_norm_step_history = [] return {"status": "ok"} except HTTPException: raise except Exception as e: logger.error(f"Error resetting training: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to reset training: {str(e)}", ) @router.get("/status") async def get_training_status( current_subject: str = Depends(get_current_subject), ): """ Get the current training status. 
""" try: backend = get_training_backend() job_id: str = getattr(backend, "current_job_id", "") or "" # Check if training is active is_active = backend.is_training_active() # Get progress info from trainer try: progress = backend.trainer.get_training_progress() except Exception: progress = None status_message = ( getattr(progress, "status_message", None) if progress else None ) or "Ready to train" error_message = getattr(progress, "error", None) if progress else None # Check if training was stopped by user trainer_stopped = getattr(backend, "_should_stop", False) # Derive high-level phase if error_message: phase = "error" elif is_active: msg_lower = status_message.lower() if "loading" in msg_lower or "importing" in msg_lower: phase = "loading_model" elif any( k in msg_lower for k in ["preparing", "initializing", "configuring"] ): phase = "configuring" else: phase = "training" elif trainer_stopped: phase = "stopped" elif progress and getattr(progress, "is_completed", False): phase = "completed" else: phase = "idle" details = None if progress: details = { "epoch": getattr(progress, "epoch", 0), "step": getattr(progress, "step", 0), "total_steps": getattr(progress, "total_steps", 0), "loss": getattr(progress, "loss", 0.0), "learning_rate": getattr(progress, "learning_rate", 0.0), } # Build metric history for chart recovery after SSE reconnection metric_history = None if backend.step_history: metric_history = { "steps": list(backend.step_history), "loss": list(backend.loss_history), "lr": list(backend.lr_history), "grad_norm": list(getattr(backend, "grad_norm_history", [])), "grad_norm_steps": list(getattr(backend, "grad_norm_step_history", [])), "eval_loss": list(backend.eval_loss_history), "eval_steps": list(backend.eval_step_history), } return TrainingStatus( job_id = job_id, phase = phase, is_training_running = is_active, eval_enabled = backend.eval_enabled, message = status_message, error = error_message, details = details, metric_history = metric_history, ) except 
Exception as e: logger.error(f"Error getting training status: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to get training status: {str(e)}" ) @router.get("/metrics", response_model = TrainingMetricsResponse) async def get_training_metrics( current_subject: str = Depends(get_current_subject), ): """ Get training metrics (loss, learning rate, steps). """ try: backend = get_training_backend() # Get metrics from backend loss_history = backend.loss_history lr_history = backend.lr_history step_history = backend.step_history grad_norm_history = getattr(backend, "grad_norm_history", []) grad_norm_step_history = getattr(backend, "grad_norm_step_history", []) # Get current values current_loss = loss_history[-1] if loss_history else None current_lr = lr_history[-1] if lr_history else None current_step = step_history[-1] if step_history else None return TrainingMetricsResponse( loss_history = loss_history, lr_history = lr_history, step_history = step_history, grad_norm_history = grad_norm_history, grad_norm_step_history = grad_norm_step_history, current_loss = current_loss, current_lr = current_lr, current_step = current_step, ) except Exception as e: logger.error(f"Error getting training metrics: {e}", exc_info = True) raise HTTPException( status_code = 500, detail = f"Failed to get training metrics: {str(e)}" ) @router.get("/progress") async def stream_training_progress( request: Request, current_subject: str = Depends(get_current_subject), ): """ Stream training progress updates using Server-Sent Events (SSE). This endpoint provides real-time updates on training progress. Supports reconnection via the SSE spec: - Sends `id:` with each event so the browser tracks position. - Sends `retry:` to control reconnection interval. - Sends named `event:` types (progress, heartbeat, complete, error). - Reads `Last-Event-ID` header on reconnect to replay missed steps. 
""" # Read Last-Event-ID header for reconnection resume last_event_id = request.headers.get("last-event-id") resume_from_step: Optional[int] = None if last_event_id is not None: try: resume_from_step = int(last_event_id) logger.info(f"SSE reconnect: resuming from step {resume_from_step}") except ValueError: logger.warning(f"Invalid Last-Event-ID: {last_event_id}") async def event_generator(): backend = get_training_backend() job_id: str = getattr(backend, "current_job_id", "") or "" # ── Helpers ────────────────────────────────────────────── def build_progress( step: int, loss: float, learning_rate: float, total_steps: int, epoch: Optional[float] = None, progress: Optional[Any] = None, grad_norm_override: Optional[float] = None, eval_loss_override: Optional[float] = None, ) -> TrainingProgress: total = max(total_steps, 0) if step < 0 or total == 0: progress_percent = 0.0 else: progress_percent = ( float(step) / float(total) * 100.0 if total > 0 else 0.0 ) # Get actual values from progress object if available elapsed_seconds = ( getattr(progress, "elapsed_seconds", None) if progress else None ) eta_seconds = getattr(progress, "eta_seconds", None) if progress else None grad_norm = grad_norm_override if grad_norm is None and progress: grad_norm = getattr(progress, "grad_norm", None) num_tokens = getattr(progress, "num_tokens", None) if progress else None eval_loss = eval_loss_override if eval_loss is None and progress: eval_loss = getattr(progress, "eval_loss", None) return TrainingProgress( job_id = job_id, step = step, total_steps = total, loss = loss, learning_rate = learning_rate, progress_percent = progress_percent, epoch = epoch, elapsed_seconds = elapsed_seconds, eta_seconds = eta_seconds, grad_norm = grad_norm, num_tokens = num_tokens, eval_loss = eval_loss, ) def format_sse( data: str, event: str = "progress", event_id: Optional[int] = None, ) -> str: """Format a single SSE message with id/event/data fields.""" lines = [] if event_id is not None: 
lines.append(f"id: {event_id}") lines.append(f"event: {event}") lines.append(f"data: {data}") lines.append("") # trailing blank line lines.append("") # double newline terminates the event return "\n".join(lines) # ── Retry directive ────────────────────────────────────── # Tell the browser to reconnect after 3 seconds if the connection drops yield "retry: 3000\n\n" # ── Replay missed steps on reconnect ───────────────────── if resume_from_step is not None and backend.step_history: replayed = 0 grad_norm_by_step = { step_val: grad_val for step_val, grad_val in zip( getattr(backend, "grad_norm_step_history", []), getattr(backend, "grad_norm_history", []), ) } for i, step_val in enumerate(backend.step_history): if step_val > resume_from_step: loss_val = ( backend.loss_history[i] if i < len(backend.loss_history) else 0.0 ) lr_val = ( backend.lr_history[i] if i < len(backend.lr_history) else 0.0 ) tp_replay = getattr( getattr(backend, "trainer", None), "training_progress", None ) total_replay = ( getattr(tp_replay, "total_steps", step_val) if tp_replay else step_val ) epoch_replay = ( getattr(tp_replay, "epoch", None) if tp_replay else None ) payload = build_progress( step_val, loss_val, lr_val, total_replay, epoch_replay, progress = tp_replay, grad_norm_override = grad_norm_by_step.get(step_val), ) yield format_sse( payload.model_dump_json(), event = "progress", event_id = step_val ) replayed += 1 if replayed: logger.info(f"SSE reconnect: replayed {replayed} missed steps") # ── Initial status (only on fresh connections) ─────────── if resume_from_step is None: is_active = backend.is_training_active() tp = getattr(getattr(backend, "trainer", None), "training_progress", None) initial_total_steps = getattr(tp, "total_steps", 0) if tp else 0 initial_epoch = getattr(tp, "epoch", None) if tp else None initial_progress = build_progress( step = 0, loss = 0.0, learning_rate = 0.0, total_steps = initial_total_steps, epoch = initial_epoch, progress = tp, ) yield format_sse( 
initial_progress.model_dump_json(), event = "progress", event_id = 0 ) # If not active, send final state and exit if not is_active: if backend.step_history: final_step = backend.step_history[-1] final_loss = ( backend.loss_history[-1] if backend.loss_history else 0.0 ) final_lr = backend.lr_history[-1] if backend.lr_history else 0.0 final_total_steps = ( getattr(tp, "total_steps", final_step) if tp else final_step ) final_epoch = getattr(tp, "epoch", None) if tp else None payload = build_progress( final_step, final_loss, final_lr, final_total_steps, final_epoch, progress = tp, ) yield format_sse( payload.model_dump_json(), event = "complete", event_id = final_step ) else: yield format_sse( build_progress(-1, 0.0, 0.0, 0, progress = tp).model_dump_json(), event = "complete", event_id = 0, ) return # ── Live polling loop ──────────────────────────────────── last_step = resume_from_step if resume_from_step is not None else -1 no_update_count = 0 max_no_updates = ( 1800 # Timeout after 30 minutes (large models need time for compilation) ) while backend.is_training_active(): try: if backend.step_history: current_step = backend.step_history[-1] current_loss = ( backend.loss_history[-1] if backend.loss_history else 0.0 ) current_lr = backend.lr_history[-1] if backend.lr_history else 0.0 tp_inner = getattr( getattr(backend, "trainer", None), "training_progress", None ) current_total_steps = ( getattr(tp_inner, "total_steps", current_step) if tp_inner else current_step ) current_epoch = ( getattr(tp_inner, "epoch", None) if tp_inner else None ) # Only send if step changed if current_step != last_step: progress_payload = build_progress( current_step, current_loss, current_lr, current_total_steps, current_epoch, progress = tp_inner, ) yield format_sse( progress_payload.model_dump_json(), event = "progress", event_id = current_step, ) last_step = current_step no_update_count = 0 else: no_update_count += 1 # Send heartbeat every 10 seconds if no_update_count % 10 == 0: 
heartbeat_payload = build_progress( current_step, current_loss, current_lr, current_total_steps, current_epoch, progress = tp_inner, ) yield format_sse( heartbeat_payload.model_dump_json(), event = "heartbeat", event_id = current_step, ) else: # No steps yet, but training is active (model loading, etc.) no_update_count += 1 if no_update_count % 5 == 0: # Pull total_steps and status from trainer so # the frontend can show "Tokenizing…" etc. tp_prep = getattr( getattr(backend, "trainer", None), "training_progress", None, ) prep_total = ( getattr(tp_prep, "total_steps", 0) if tp_prep else 0 ) preparing_payload = build_progress( 0, 0.0, 0.0, prep_total, progress = tp_prep, ) yield format_sse( preparing_payload.model_dump_json(), event = "heartbeat", event_id = 0, ) # Timeout check if no_update_count > max_no_updates: logger.warning("Progress stream timeout - no updates received") tp_timeout = getattr( getattr(backend, "trainer", None), "training_progress", None ) timeout_payload = build_progress( last_step, 0.0, 0.0, 0, progress = tp_timeout ) yield format_sse( timeout_payload.model_dump_json(), event = "error", event_id = last_step if last_step >= 0 else 0, ) break await asyncio.sleep(1) # Poll every second except Exception as e: logger.error(f"Error in progress stream: {e}", exc_info = True) tp_error = getattr( getattr(backend, "trainer", None), "training_progress", None ) error_payload = build_progress(0, 0.0, 0.0, 0, progress = tp_error) yield format_sse( error_payload.model_dump_json(), event = "error", event_id = last_step if last_step >= 0 else 0, ) break # ── Final "complete" event ─────────────────────────────── final_step = backend.step_history[-1] if backend.step_history else last_step final_loss = backend.loss_history[-1] if backend.loss_history else 0.0 final_lr = backend.lr_history[-1] if backend.lr_history else 0.0 final_tp = getattr(getattr(backend, "trainer", None), "training_progress", None) final_total_steps = ( getattr(final_tp, "total_steps", 
final_step) if final_tp else final_step ) final_epoch = getattr(final_tp, "epoch", None) if final_tp else None final_payload = build_progress( final_step, final_loss, final_lr, final_total_steps, final_epoch, progress = final_tp, ) yield format_sse( final_payload.model_dump_json(), event = "complete", event_id = final_step if final_step >= 0 else 0, ) return StreamingResponse( event_generator(), media_type = "text/event-stream", headers = { "Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no", }, ) ================================================ FILE: studio/backend/run.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Run script for Unsloth UI Backend. Works independently and can be moved to any directory. """ import os import sys # Suppress annoying C-level dependency warnings globally (e.g. SwigPyPacked) os.environ["PYTHONWARNINGS"] = "ignore" from pathlib import Path # Add the backend directory to Python path backend_dir = Path(__file__).parent if str(backend_dir) not in sys.path: sys.path.insert(0, str(backend_dir)) from loggers import get_logger logger = get_logger(__name__) def _resolve_external_ip() -> str: """ Resolve the machine's external IP address. Tries (in order): 1. GCE metadata server (instant, works on Google Cloud VMs) 2. ifconfig.me (works anywhere with internet) 3. LAN IP via UDP socket trick (fallback) """ import urllib.request import socket # 1. Try GCE metadata server (responds in <10ms on GCE, times out fast elsewhere) try: req = urllib.request.Request( "http://metadata.google.internal/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip", headers = {"Metadata-Flavor": "Google"}, ) with urllib.request.urlopen(req, timeout = 1) as resp: ip = resp.read().decode().strip() if ip: return ip except Exception: pass # 2. 
Try public IP service try: with urllib.request.urlopen("https://ifconfig.me", timeout = 3) as resp: ip = resp.read().decode().strip() if ip: return ip except Exception: pass # 3. Fallback: LAN IP via UDP socket trick try: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect(("8.8.8.8", 80)) ip = s.getsockname()[0] s.close() return ip except Exception: return "0.0.0.0" def _is_port_free(host: str, port: int) -> bool: """Check if a port is available for binding.""" import socket try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind((host, port)) return True except OSError: return False def _find_free_port(host: str, start: int, max_attempts: int = 20) -> int: """Find a free port starting from `start`, trying up to max_attempts ports.""" for offset in range(max_attempts): candidate = start + offset if _is_port_free(host, candidate): return candidate raise RuntimeError( f"Could not find a free port in range {start}-{start + max_attempts - 1}" ) def _graceful_shutdown(server = None): """Explicitly shut down all subprocess backends and the uvicorn server. Called from signal handlers to ensure child processes are cleaned up before the parent exits. This is critical on Windows where atexit handlers are unreliable after Ctrl+C. """ logger.info("Graceful shutdown initiated — cleaning up subprocesses...") # 1. Shut down uvicorn server (releases the listening socket) if server is not None: server.should_exit = True # 2. Clean up inference subprocess (if instantiated) try: from core.inference.orchestrator import _inference_backend if _inference_backend is not None: _inference_backend._shutdown_subprocess(timeout = 5.0) except Exception as e: logger.warning("Error shutting down inference subprocess: %s", e) # 3. 
Clean up export subprocess (if instantiated) try: from core.export.orchestrator import _export_backend if _export_backend is not None: _export_backend._shutdown_subprocess(timeout = 5.0) except Exception as e: logger.warning("Error shutting down export subprocess: %s", e) # 4. Clean up training subprocess (if active) try: from core.training.training import _training_backend if _training_backend is not None: _training_backend.force_terminate() except Exception as e: logger.warning("Error shutting down training subprocess: %s", e) # 5. Kill llama-server subprocess (if loaded) try: from routes.inference import _llama_cpp_backend if _llama_cpp_backend is not None: _llama_cpp_backend._kill_process() except Exception as e: logger.warning("Error shutting down llama-server: %s", e) logger.info("All subprocesses cleaned up") # The uvicorn server instance — set by run_server(), used by callers # that need to tell the server to exit (e.g. signal handlers). _server = None # Shutdown event — used to wake the main loop on signal _shutdown_event = None def run_server( host: str = "0.0.0.0", port: int = 8888, frontend_path: Path = Path(__file__).resolve().parent.parent / "frontend" / "dist", silent: bool = False, ): """ Start the FastAPI server. Args: host: Host to bind to port: Port to bind to (auto-increments if in use) frontend_path: Path to frontend build directory (optional) silent: Suppress startup messages Note: Signal handlers are NOT registered here so that embedders (e.g. Colab notebooks) keep their own interrupt semantics. Standalone callers should register handlers after calling this. 
""" global _server, _shutdown_event import nest_asyncio nest_asyncio.apply() import asyncio from threading import Thread, Event import time import uvicorn from main import app, setup_frontend from utils.paths import ensure_studio_directories # Create all standard directories on startup ensure_studio_directories() # Auto-find free port if requested port is in use if not _is_port_free(host, port): original_port = port port = _find_free_port(host, port) if not silent: print(f"Port {original_port} is in use, using port {port} instead") # Setup frontend if path provided if frontend_path: if setup_frontend(app, frontend_path): if not silent: print(f"✅ Frontend loaded from {frontend_path}") else: if not silent: print(f"⚠️ Frontend not found at {frontend_path}") # Create the uvicorn server and expose it for signal handlers config = uvicorn.Config( app, host = host, port = port, log_level = "info", access_log = False ) _server = uvicorn.Server(config) _shutdown_event = Event() # Run server in a daemon thread def _run(): asyncio.run(_server.serve()) thread = Thread(target = _run, daemon = True) thread.start() time.sleep(3) if not silent: display_host = _resolve_external_ip() if host == "0.0.0.0" else host print("") print("=" * 50) print(f"🦥 Open your web browser, and enter http://localhost:{port}") print("=" * 50) print("") print("=" * 50) print(f"🦥 Unsloth Studio is running on port {port}") print(f" Local Access: http://localhost:{port}") print(f" Worldwide Web Address: http://{display_host}:{port}") print(f" API: http://{display_host}:{port}/api") print(f" Health: http://{display_host}:{port}/api/health") print("=" * 50) return app # For direct execution (also invoked by CLI via os.execvp / subprocess) if __name__ == "__main__": import argparse import signal parser = argparse.ArgumentParser(description = "Run Unsloth UI Backend server") parser.add_argument("--host", default = "0.0.0.0", help = "Host to bind to") parser.add_argument("--port", type = int, default = 8888, 
help = "Port to bind to") parser.add_argument( "--frontend", type = str, default = Path(__file__).resolve().parent.parent / "frontend" / "dist", help = "Path to frontend build", ) parser.add_argument("--silent", action = "store_true", help = "Suppress output") args = parser.parse_args() kwargs = dict(host = args.host, port = args.port, silent = args.silent) if args.frontend is not None: kwargs["frontend_path"] = Path(args.frontend) run_server(**kwargs) # ── Signal handler — ensures subprocess cleanup on Ctrl+C ──── def _signal_handler(signum, frame): _graceful_shutdown(_server) _shutdown_event.set() signal.signal(signal.SIGINT, _signal_handler) signal.signal(signal.SIGTERM, _signal_handler) # On Windows, some terminals send SIGBREAK for Ctrl+C / Ctrl+Break if hasattr(signal, "SIGBREAK"): signal.signal(signal.SIGBREAK, _signal_handler) # Keep running until shutdown signal. # NOTE: Event.wait() without a timeout blocks at the C level on Linux, # which prevents Python from delivering SIGINT (Ctrl+C). Using a # short timeout in a loop lets the interpreter process pending signals. while not _shutdown_event.is_set(): _shutdown_event.wait(timeout = 1) ================================================ FILE: studio/backend/state/.gitkeep ================================================ ================================================ FILE: studio/backend/state/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/tests/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/tests/conftest.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Shared pytest configuration for the backend test suite. Ensures that the backend root is on sys.path so that `import utils.utils` (and similar flat imports) resolve correctly. """ import sys from pathlib import Path # Add backend root to sys.path (mirrors how the app itself is launched) _backend_root = Path(__file__).resolve().parent.parent if str(_backend_root) not in sys.path: sys.path.insert(0, str(_backend_root)) ================================================ FILE: studio/backend/tests/test_data_recipe_seed.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 from pathlib import Path def test_seed_inspect_load_kwargs_disables_remote_code_execution(): seed_route = ( Path(__file__).resolve().parent.parent / "routes" / "data_recipe" / "seed.py" ).read_text() assert '"trust_remote_code": False' in seed_route ================================================ FILE: studio/backend/tests/test_utils.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Tests for utils/hardware and utils/utils — device detection, GPU memory, error formatting. These tests are designed to pass on ANY platform: • NVIDIA GPU (CUDA backend, requires torch) • Apple Silicon (MLX backend, requires mlx) • CPU-only (no GPU at all) No ML framework is imported at the top level. Tests that need torch/mlx internals for mocking are skipped when unavailable. 
Run with: cd studio/backend python -m pytest tests/test_utils.py -v """ import platform from unittest.mock import patch, MagicMock import pytest # --- Conditional framework imports --- try: import torch HAS_TORCH = True except ImportError: HAS_TORCH = False try: import mlx.core as mx HAS_MLX = True except ImportError: HAS_MLX = False needs_torch = pytest.mark.skipif(not HAS_TORCH, reason = "PyTorch not installed") needs_mlx = pytest.mark.skipif(not HAS_MLX, reason = "MLX not installed") from utils.hardware import ( get_device, detect_hardware, is_apple_silicon, clear_gpu_cache, get_gpu_memory_info, log_gpu_memory, DeviceType, ) import utils.hardware.hardware as _hw_module from utils.utils import format_error_message # ========== Helpers ========== def _actual_device() -> str: """Return the real device string for the current machine.""" if HAS_TORCH and torch.cuda.is_available(): return "cuda" if is_apple_silicon() and HAS_MLX: return "mlx" return "cpu" def _reset_and_detect(): """Reset the cached DEVICE global and re-run detection.""" _hw_module.DEVICE = None return detect_hardware() # ========== get_device() ========== class TestGetDevice: """Tests for get_device() — should agree with the real hardware.""" def setup_method(self): self._saved_device = _hw_module.DEVICE def teardown_method(self): _hw_module.DEVICE = self._saved_device def test_returns_valid_device_type(self): result = get_device() assert result in (DeviceType.CUDA, DeviceType.MLX, DeviceType.CPU) def test_matches_actual_hardware(self): assert get_device().value == _actual_device() # --- Mocked paths --- @needs_torch def test_returns_cuda_when_cuda_available(self): with ( patch("utils.hardware.hardware._has_torch", return_value = True), patch("torch.cuda.is_available", return_value = True), ): assert _reset_and_detect() == DeviceType.CUDA @needs_mlx def test_returns_mlx_when_on_apple_silicon_with_mlx(self): with ( patch("utils.hardware.hardware._has_torch", return_value = False), 
patch("utils.hardware.hardware.is_apple_silicon", return_value = True), patch("utils.hardware.hardware._has_mlx", return_value = True), ): assert _reset_and_detect() == DeviceType.MLX def test_returns_cpu_when_nothing_available(self): with ( patch("utils.hardware.hardware._has_torch", return_value = False), patch("utils.hardware.hardware.is_apple_silicon", return_value = False), patch("utils.hardware.hardware._has_mlx", return_value = False), ): assert _reset_and_detect() == DeviceType.CPU # ========== is_apple_silicon() ========== class TestIsAppleSilicon: def test_returns_bool(self): assert isinstance(is_apple_silicon(), bool) def test_true_on_darwin_arm64(self): with patch("utils.hardware.hardware.platform") as mock_plat: mock_plat.system.return_value = "Darwin" mock_plat.machine.return_value = "arm64" assert is_apple_silicon() is True def test_false_on_linux_x86(self): with patch("utils.hardware.hardware.platform") as mock_plat: mock_plat.system.return_value = "Linux" mock_plat.machine.return_value = "x86_64" assert is_apple_silicon() is False def test_false_on_darwin_x86(self): """Intel Mac should return False.""" with patch("utils.hardware.hardware.platform") as mock_plat: mock_plat.system.return_value = "Darwin" mock_plat.machine.return_value = "x86_64" assert is_apple_silicon() is False # ========== clear_gpu_cache() ========== class TestClearGpuCache: """clear_gpu_cache() must never raise, regardless of platform.""" def test_does_not_raise(self): clear_gpu_cache() @needs_torch def test_calls_cuda_cache_when_cuda(self): with ( patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), patch("torch.cuda.empty_cache") as mock_empty, patch("torch.cuda.ipc_collect") as mock_ipc, ): clear_gpu_cache() mock_empty.assert_called_once() mock_ipc.assert_called_once() @needs_mlx def test_mlx_does_not_raise(self): """MLX cache clear is a no-op — should just succeed.""" with patch("utils.hardware.hardware.get_device", return_value = DeviceType.MLX): 
clear_gpu_cache() def test_noop_on_cpu(self): with patch("utils.hardware.hardware.get_device", return_value = DeviceType.CPU): clear_gpu_cache() # ========== get_gpu_memory_info() ========== class TestGetGpuMemoryInfo: def test_returns_dict(self): result = get_gpu_memory_info() assert isinstance(result, dict) def test_has_available_key(self): assert "available" in get_gpu_memory_info() def test_has_backend_key(self): assert "backend" in get_gpu_memory_info() def test_backend_matches_device(self): result = get_gpu_memory_info() assert result["backend"] == get_device().value # --- When a GPU IS available --- @pytest.mark.skipif( _actual_device() == "cpu", reason = "No GPU available on this machine" ) def test_gpu_available_fields(self): result = get_gpu_memory_info() assert result["available"] is True assert result["total_gb"] > 0 assert result["allocated_gb"] >= 0 assert result["free_gb"] >= 0 assert 0 <= result["utilization_pct"] <= 100 assert "device_name" in result # --- CUDA-specific mocked test --- @needs_torch def test_cuda_path_returns_correct_fields(self): mock_props = MagicMock() mock_props.total_memory = 16 * (1024**3) mock_props.name = "NVIDIA Test GPU" with ( patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), patch("torch.cuda.current_device", return_value = 0), patch("torch.cuda.get_device_properties", return_value = mock_props), patch("torch.cuda.memory_allocated", return_value = 4 * (1024**3)), patch("torch.cuda.memory_reserved", return_value = 6 * (1024**3)), ): result = get_gpu_memory_info() assert result["available"] is True assert result["backend"] == "cuda" assert result["device_name"] == "NVIDIA Test GPU" assert abs(result["total_gb"] - 16.0) < 0.01 assert abs(result["allocated_gb"] - 4.0) < 0.01 assert abs(result["free_gb"] - 12.0) < 0.01 assert abs(result["utilization_pct"] - 25.0) < 0.1 # --- MLX-specific mocked test --- @needs_mlx def test_mlx_path_returns_correct_fields(self): mock_psutil_mem = MagicMock() 
mock_psutil_mem.total = 32 * (1024**3) # 32 GB unified mock_psutil = MagicMock() mock_psutil.virtual_memory.return_value = mock_psutil_mem with ( patch("utils.hardware.hardware.get_device", return_value = DeviceType.MLX), patch.dict("sys.modules", {"psutil": mock_psutil}), ): result = get_gpu_memory_info() assert result["available"] is True assert result["backend"] == "mlx" assert "Apple Silicon" in result["device_name"] assert abs(result["total_gb"] - 32.0) < 0.01 # --- CPU-only path --- def test_cpu_path_returns_unavailable(self): with patch("utils.hardware.hardware.get_device", return_value = DeviceType.CPU): result = get_gpu_memory_info() assert result["available"] is False assert result["backend"] == "cpu" # --- Error resilience --- @needs_torch def test_cuda_error_returns_unavailable(self): with ( patch("utils.hardware.hardware.get_device", return_value = DeviceType.CUDA), patch( "torch.cuda.current_device", side_effect = RuntimeError("CUDA init failed"), ), ): result = get_gpu_memory_info() assert result["available"] is False assert "error" in result # ========== log_gpu_memory() ========== class TestLogGpuMemory: def test_does_not_raise(self): log_gpu_memory("test") def test_logs_gpu_info_when_available(self, caplog): fake_info = { "available": True, "backend": "cuda", "device_name": "FakeGPU", "allocated_gb": 2.0, "total_gb": 16.0, "utilization_pct": 12.5, "free_gb": 14.0, } import structlog from loggers import get_logger with ( patch( "utils.hardware.hardware.get_gpu_memory_info", return_value = fake_info ), caplog.at_level(logging.INFO, logger = "utils.hardware.hardware"), ): log_gpu_memory("unit-test") assert "unit-test" in caplog.text assert "CUDA" in caplog.text assert "FakeGPU" in caplog.text def test_logs_cpu_fallback_when_no_gpu(self, caplog): fake_info = {"available": False, "backend": "cpu"} import structlog from loggers import get_logger with ( patch( "utils.hardware.hardware.get_gpu_memory_info", return_value = fake_info ), 
caplog.at_level(logging.INFO, logger = "utils.hardware.hardware"), ): log_gpu_memory("cpu-test") assert "No GPU available" in caplog.text # ========== format_error_message() ========== class TestFormatErrorMessage: def test_not_found(self): err = Exception("Repository not found for unsloth/test") msg = format_error_message(err, "unsloth/test") assert "not found" in msg.lower() assert "test" in msg def test_unauthorized(self): err = Exception("401 Unauthorized") msg = format_error_message(err, "some/model") assert "authentication" in msg.lower() or "unauthorized" in msg.lower() def test_gated_model(self): err = Exception("Access to model requires authentication") msg = format_error_message(err, "meta/llama") assert "authentication" in msg.lower() def test_invalid_token(self): err = Exception("Invalid user token") msg = format_error_message(err, "any/model") assert "invalid" in msg.lower() # --- OOM on CUDA --- @needs_torch def test_cuda_oom(self): err = Exception("CUDA out of memory") with patch("utils.hardware.get_device", return_value = DeviceType.CUDA): msg = format_error_message(err, "big/model") assert "GPU" in msg assert "big/model" not in msg assert "model" in msg # --- OOM on MLX --- @needs_mlx def test_mlx_oom(self): err = Exception("MLX backend out of memory") with patch("utils.hardware.get_device", return_value = DeviceType.MLX): msg = format_error_message(err, "unsloth/huge-model") assert "Apple Silicon" in msg # --- OOM on CPU --- def test_cpu_oom(self): err = Exception("not enough memory to allocate") with patch("utils.hardware.get_device", return_value = DeviceType.CPU): msg = format_error_message(err, "any/model") assert "system" in msg.lower() # --- Generic fallback --- def test_generic_error(self): err = Exception("Something completely unexpected") msg = format_error_message(err, "any/model") assert msg == "Something completely unexpected" ================================================ FILE: studio/backend/utils/.gitkeep 
================================================ ================================================ FILE: studio/backend/utils/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 ================================================ FILE: studio/backend/utils/cache_cleanup.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Utility for cleaning up the Unsloth compiled cache directory. The unsloth_compiled_cache is created by unsloth_zoo/compiler.py during FastModel.from_pretrained() and contains model-type-specific compiled Python files. It should be cleared between model loads to avoid stale artefacts. """ import shutil import structlog from loggers import get_logger from pathlib import Path logger = get_logger(__name__) # Possible locations where unsloth_compiled_cache may appear _BACKEND_DIR = Path(__file__).resolve().parent.parent # studio/backend _PROJECT_ROOT = _BACKEND_DIR.parent.parent # repo root _CACHE_DIRS = [ _BACKEND_DIR / "unsloth_compiled_cache", _PROJECT_ROOT / "unsloth_compiled_cache", _PROJECT_ROOT / "studio" / "tmp" / "unsloth_compiled_cache", ] def clear_unsloth_compiled_cache() -> None: """Remove every known unsloth_compiled_cache directory (idempotent).""" for cache_dir in _CACHE_DIRS: if cache_dir.exists(): logger.info(f"Removing unsloth compiled cache: {cache_dir}") shutil.rmtree(cache_dir, ignore_errors = True) ================================================ FILE: studio/backend/utils/datasets/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Dataset utilities package. 
This package provides utilities for dataset format detection, conversion, and processing for LLM and VLM fine-tuning workflows. Modules: - format_detection: Detect dataset formats (Alpaca, ShareGPT, ChatML) - format_conversion: Convert between dataset formats - chat_templates: Apply chat templates to datasets - vlm_processing: Vision-Language Model processing utilities - data_collators: Custom data collators for training - model_mappings: Model-to-template mapping constants """ # Format detection from .format_detection import ( detect_dataset_format, detect_custom_format_heuristic, detect_multimodal_dataset, detect_vlm_dataset_structure, ) # Format conversion from .format_conversion import ( standardize_chat_format, convert_chatml_to_alpaca, convert_alpaca_to_chatml, convert_to_vlm_format, convert_llava_to_vlm_format, convert_sharegpt_with_images_to_vlm_format, ) # Chat templates from .chat_templates import ( apply_chat_template_to_dataset, get_dataset_info_summary, get_tokenizer_chat_template, DEFAULT_ALPACA_TEMPLATE, ) # VLM processing from .vlm_processing import ( generate_smart_vlm_instruction, ) # Data collators from .data_collators import ( DataCollatorSpeechSeq2SeqWithPadding, DeepSeekOCRDataCollator, VLMDataCollator, ) # Model mappings (constants) from .model_mappings import ( TEMPLATE_TO_MODEL_MAPPER, MODEL_TO_TEMPLATE_MAPPER, TEMPLATE_TO_RESPONSES_MAPPER, ) # Legacy imports from the original dataset_utils.py for backward compatibility # These functions have not yet been refactored into separate modules from .dataset_utils import ( check_dataset_format, format_and_template_dataset, format_dataset, ) # Public API __all__ = [ # Detection "detect_dataset_format", "detect_custom_format_heuristic", "detect_multimodal_dataset", "detect_vlm_dataset_structure", # Conversion "standardize_chat_format", "convert_chatml_to_alpaca", "convert_alpaca_to_chatml", "convert_to_vlm_format", "convert_llava_to_vlm_format", "convert_sharegpt_with_images_to_vlm_format", # 
Templates "apply_chat_template_to_dataset", "get_dataset_info_summary", "get_tokenizer_chat_template", "DEFAULT_ALPACA_TEMPLATE", # VLM "generate_smart_vlm_instruction", # Collators "DataCollatorSpeechSeq2SeqWithPadding", "DeepSeekOCRDataCollator", "VLMDataCollator", # Mappings "TEMPLATE_TO_MODEL_MAPPER", "MODEL_TO_TEMPLATE_MAPPER", "TEMPLATE_TO_RESPONSES_MAPPER", # Main entry points "check_dataset_format", "format_and_template_dataset", "format_dataset", ] ================================================ FILE: studio/backend/utils/datasets/chat_templates.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Chat template application utilities for dataset processing. This module contains functions for applying chat templates to datasets and generating dataset info summaries. """ from torch.utils.data import IterableDataset from .format_detection import detect_dataset_format, detect_multimodal_dataset, detect_custom_format_heuristic from .model_mappings import MODEL_TO_TEMPLATE_MAPPER from loggers import get_logger logger = get_logger(__name__) DEFAULT_ALPACA_TEMPLATE = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: {} ### Input: {} ### Response: {}""" def get_tokenizer_chat_template(tokenizer, model_name): """ Gets appropriate chat template for tokenizer based on model. Uses Unsloth's get_chat_template if model is in the mapper. 
Args: tokenizer: HuggingFace tokenizer model_name: Model class name (e.g., "Gemma3ForCausalLM") Returns: tokenizer: Tokenizer with appropriate chat template applied """ try: from unsloth.chat_templates import get_chat_template except ImportError: # Unsloth not available, return tokenizer as-is return tokenizer # Normalize model_name to lowercase for matching model_name_lower = model_name.lower() # Check if model matches any template in mapper matched_template = None # Direct match in MODEL_TO_TEMPLATE_MAPPER if model_name_lower in MODEL_TO_TEMPLATE_MAPPER: matched_template = MODEL_TO_TEMPLATE_MAPPER[model_name_lower] logger.info(f"📝 Applying Unsloth chat template: {matched_template}") try: tokenizer = get_chat_template( tokenizer, chat_template = matched_template, ) except Exception as e: logger.info(f"⚠️ Failed to apply Unsloth template '{matched_template}': {e}") logger.info(f" Falling back to tokenizer's default chat template") else: # Check if tokenizer actually has a chat_template set has_chat_template = ( hasattr(tokenizer, 'chat_template') and tokenizer.chat_template is not None ) if has_chat_template: logger.info(f"📝 Using tokenizer's own chat template (no Unsloth template match)") else: # Base model with no chat template — apply default ChatML logger.info(f"📝 No chat template found — applying default ChatML template (base model)") try: tokenizer = get_chat_template( tokenizer, chat_template = "chatml", ) except Exception as e: logger.info(f"⚠️ Failed to apply default ChatML template: {e}") logger.info(f" Falling back to tokenizer as-is") return tokenizer def get_dataset_info_summary(dataset_info): """ Returns a human-readable summary for UI display. 
""" detected_format = dataset_info["detected_format"] final_format = dataset_info["final_format"] format_descriptions = { "alpaca": "Alpaca format (instruction/input/output)", "sharegpt": "ShareGPT format (needs standardization)", "chatml_messages": "ChatML format (messages column) - OpenAI compatible", "chatml_conversations": "ChatML format (conversations column) - HuggingFace standard", "unknown": "Unknown format" } return { "detected_format": detected_format, "final_format": final_format, "detected_description": format_descriptions.get(detected_format, "Unknown"), "final_description": format_descriptions.get(final_format, "Unknown"), "chat_column": dataset_info["chat_column"], "is_standardized": dataset_info["is_standardized"], "warnings": dataset_info.get("warnings", []), "ready_for_training": dataset_info["is_standardized"] and final_format != "unknown" } def apply_chat_template_to_dataset( dataset_info, tokenizer, model_name = None, custom_prompt_template = None, add_eos_token = False, remove_bos_prefix = False, custom_format_mapping = None, auto_detect_mapping = True, batch_size = 1000, num_proc = None, progress_callback = None, ): """ Applies chat template to dataset based on its format. Args: dataset_info: Output from format_dataset() with metadata tokenizer: Tokenizer with chat template custom_prompt_template: Optional string template for custom formatting add_eos_token: If True, appends tokenizer.eos_token to each text remove_bos_prefix: If True, removes '' prefix (for Gemma, etc.) 
custom_format_mapping: Dict mapping custom columns to standard format batch_size: Batch size for processing num_proc: Number of processes Returns: dict with dataset, success status, warnings, and errors """ dataset = dataset_info["dataset"] final_format = dataset_info["final_format"] chat_column = dataset_info["chat_column"] is_standardized = dataset_info["is_standardized"] warnings = list(dataset_info.get("warnings", [])) errors = [] # Get EOS token if needed eos_token = "" if add_eos_token: if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token: eos_token = tokenizer.eos_token else: warnings.append("add_eos_token=True but tokenizer has no eos_token") # CUSTOM FORMAT MAPPING (for non-standard datasets) if final_format == "unknown": # Try auto-detection if no custom mapping provided if custom_format_mapping is None and auto_detect_mapping: # Check if format_dataset already tried and failed if not dataset_info.get("auto_detection_attempted", False): custom_format_mapping = detect_custom_format_heuristic(dataset) if custom_format_mapping: warnings.append(f"Auto-detected column mapping: {custom_format_mapping}") else: errors.append("Could not auto-detect format mapping") return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } else: # Already failed once in format_dataset, don't retry errors.append( "Format remains unknown after detection attempts. " "Please provide custom_format_mapping to specify column roles manually." 
) return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } if custom_format_mapping: warnings.append(f"Applying custom format mapping: {custom_format_mapping}") is_user_provided = dataset_info.get("custom_format_mapping") is not None def _apply_custom_mapping(examples): conversations = [] num_examples = len(examples[list(examples.keys())[0]]) # Only preserve unmapped columns if auto-detected preserved_columns = {} if not is_user_provided: all_columns = set(examples.keys()) mapped_columns = set(custom_format_mapping.keys()) non_mapped_columns = all_columns - mapped_columns for col in non_mapped_columns: preserved_columns[col] = examples[col] for i in range(num_examples): convo = [] role_order = ['system', 'user', 'assistant'] for target_role in role_order: for col_name, role in custom_format_mapping.items(): if role == target_role and col_name in examples: content = examples[col_name][i] if is_user_provided: # User explicitly mapped - include even if empty convo.append({"role": role, "content": str(content) if content else ""}) else: # Auto-detected - skip empty if content and str(content).strip(): convo.append({"role": role, "content": str(content)}) conversations.append(convo) result = {"conversations": conversations} if not is_user_provided: result.update(preserved_columns) return result try: dataset = dataset.map(_apply_custom_mapping, batched = True, batch_size = batch_size) # Update to use conversations format final_format = "chatml_conversations" chat_column = "conversations" is_standardized = True warnings.append("Successfully converted to ChatML format via custom mapping") except Exception as e: errors.append(f"Custom format mapping failed: {e}") return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } # ALPACA FORMAT if final_format == "alpaca": # Set alpaca chat template on tokenizer for saving (if not already set) # This ensures the template is saved with the model for inference if not 
(hasattr(tokenizer, 'chat_template') and tokenizer.chat_template): try: from unsloth.chat_templates import get_chat_template tokenizer = get_chat_template(tokenizer, chat_template = "alpaca") logger.info(f"📝 Set alpaca chat template on tokenizer for model saving") except Exception as e: logger.info(f"⚠️ Could not set alpaca template on tokenizer: {e}") # Use custom template if provided def _format_alpaca_custom(examples): texts = [] for i in range(len(examples["instruction"])): fields = { "instruction": examples["instruction"][i], "input": examples.get("input", [""] * len(examples["instruction"]))[i], "output": examples["output"][i] } try: text = DEFAULT_ALPACA_TEMPLATE.format(fields["instruction"], fields["input"], fields["output"]) text += eos_token texts.append(text) except KeyError as e: errors.append(f"Custom template missing field: {e}") texts.append("") return {"text": texts} formatted_fn = _format_alpaca_custom try: dataset_map_kwargs = { 'batched': True, 'batch_size': batch_size, } if not isinstance(dataset, IterableDataset): from utils.hardware import safe_num_proc if num_proc is None or type(num_proc) is not int: num_proc = safe_num_proc() else: num_proc = safe_num_proc(num_proc) dataset_map_kwargs['num_proc'] = num_proc dataset_map_kwargs['desc'] = "Applying template to Alpaca format" formatted_dataset = dataset.map(formatted_fn, **dataset_map_kwargs) return { "dataset": formatted_dataset, "success": True, "warnings": warnings, "errors": errors } except Exception as e: errors.append(f"Failed to format Alpaca dataset: {e}") return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } # CHATML FORMATS elif final_format in ["chatml_messages", "chatml_conversations"]: if not is_standardized: warnings.append("Dataset may not be fully standardized") # Apply Unsloth chat template if model matches if model_name: tokenizer = get_tokenizer_chat_template(tokenizer, model_name) def _format_chatml(examples): convos = examples[chat_column] 
texts = [] for convo in convos: try: text = tokenizer.apply_chat_template( convo, tokenize = False, add_generation_prompt = False ) if remove_bos_prefix: text = text.removeprefix('') text += eos_token texts.append(text) except Exception as e: if len(texts) == 0: warnings.append(f"Chat template failed: {e}") texts.append("") return {"text": texts} try: dataset_map_kwargs = { 'batched': True, 'batch_size': batch_size, } if not isinstance(dataset, IterableDataset): from utils.hardware import safe_num_proc if num_proc is None or type(num_proc) is not int: num_proc = safe_num_proc() else: num_proc = safe_num_proc(num_proc) dataset_map_kwargs['num_proc'] = num_proc dataset_map_kwargs['desc'] = f"Applying chat template to {final_format}" # Monitor tqdm progress from dataset.map() and relay to callback _tqdm_monitor_stop = None if progress_callback and not isinstance(dataset, IterableDataset): import threading from tqdm.auto import tqdm as _tqdm_cls _tqdm_monitor_stop = threading.Event() _total = len(dataset) if hasattr(dataset, "__len__") else 0 _desc = f"Applying chat template to {final_format}" def _poll_tqdm(): while not _tqdm_monitor_stop.is_set(): for bar in list(getattr(_tqdm_cls, "_instances", set())): try: n = bar.n or 0 total = bar.total or _total if total > 0 and n > 0: pct = min(int(n * 100 / total), 100) progress_callback( status_message = f"{_desc}... 
{pct}% ({n:,}/{total:,})" ) except (AttributeError, ReferenceError): pass _tqdm_monitor_stop.wait(3) threading.Thread(target = _poll_tqdm, daemon = True).start() formatted_dataset = dataset.map(_format_chatml, **dataset_map_kwargs) if _tqdm_monitor_stop is not None: _tqdm_monitor_stop.set() return { "dataset": formatted_dataset, "success": True, "warnings": warnings, "errors": errors } except Exception as e: errors.append(f"Failed to format ChatML dataset: {e}") return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } # UNKNOWN FORMAT else: errors.append( f"Cannot apply chat template to format: {final_format}. " f"This should not happen after custom mapping." ) return { "dataset": dataset, "success": False, "warnings": warnings, "errors": errors } ================================================ FILE: studio/backend/utils/datasets/data_collators.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Data collators for dataset processing. This module contains custom data collators for training, particularly for VLM/OCR processing. """ import torch from dataclasses import dataclass from typing import Any, List, Optional, Union from loggers import get_logger logger = get_logger(__name__) @dataclass class DataCollatorSpeechSeq2SeqWithPadding: """ Data collator for Whisper speech-to-text training. Pads input features (audio) and label sequences (text) separately, masks padding in labels with -100, and strips leading BOS token. Mirrors the collator from the Whisper.ipynb notebook. 
""" processor: Any def __call__(self, features: List[dict]) -> dict: input_features = [ {"input_features": feature["input_features"]} for feature in features ] batch = self.processor.feature_extractor.pad( input_features, return_tensors = "pt" ) label_features = [{"input_ids": feature["labels"]} for feature in features] labels_batch = self.processor.tokenizer.pad(label_features, return_tensors = "pt") labels = labels_batch["input_ids"].masked_fill( labels_batch.attention_mask.ne(1), -100 ) if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): labels = labels[:, 1:] batch["labels"] = labels return batch @dataclass class DeepSeekOCRDataCollator: """ Data collator for DeepSeek OCR VLM training. Handles: - Image processing via processor - Text tokenization - Proper label masking for instruction fine-tuning """ processor: Any # Qwen2VLProcessor or similar max_length: int = 2048 ignore_index: int = -100 def __call__(self, batch: List[dict]) -> dict: """ Collate a batch of samples. Args: batch: List of dicts, each with 'messages' containing [{'role': 'user', 'content': [...]}, {'role': 'assistant', 'content': [...]}] Returns: dict with input_ids, attention_mask, labels, pixel_values, etc. 
""" from PIL import Image # Extract messages and images all_messages = [] all_images = [] for sample in batch: messages = sample["messages"] all_messages.append(messages) # Extract PIL images from content for msg in messages: content = msg.get("content", []) if isinstance(content, list): for item in content: if isinstance(item, dict) and item.get("type") == "image": img = item.get("image") if img is not None and hasattr(img, "size"): # PIL Image all_images.append(img) # Process with the VL processor try: # Qwen2VL style processing texts = [ self.processor.apply_chat_template( msgs, tokenize = False, add_generation_prompt = False ) for msgs in all_messages ] # Process with images inputs = self.processor( text = texts, images = all_images if all_images else None, return_tensors = "pt", padding = True, truncation = True, max_length = self.max_length, ) # Create labels (mask input, keep output) labels = inputs["input_ids"].clone() # Simple masking: mask padding tokens labels[labels == self.processor.tokenizer.pad_token_id] = self.ignore_index inputs["labels"] = labels return inputs except Exception as e: logger.info(f"⚠️ DeepSeekOCRDataCollator error: {e}") raise @dataclass class VLMDataCollator: """ Generic VLM data collator that works with various processors. Supports: - Qwen2VL - LLaVA - Other VL models with compatible processors """ processor: Any max_length: int = 2048 ignore_index: int = -100 mask_input_tokens: bool = True # Whether to mask user tokens in labels def __call__(self, batch: List[dict]) -> dict: """ Collate a batch of VLM samples. 
""" all_messages = [] all_images = [] for sample in batch: messages = sample.get("messages", []) all_messages.append(messages) # Extract images for msg in messages: content = msg.get("content", []) if isinstance(content, list): for item in content: if isinstance(item, dict): img = item.get("image") if img is not None: all_images.append(img) # Apply chat template texts = [ self.processor.apply_chat_template( msgs, tokenize = False, add_generation_prompt = False ) for msgs in all_messages ] # Process inputs inputs = self.processor( text = texts, images = all_images if all_images else None, return_tensors = "pt", padding = True, truncation = True, max_length = self.max_length, ) # Create labels labels = inputs["input_ids"].clone() # Mask padding if hasattr(self.processor, "tokenizer"): pad_token_id = self.processor.tokenizer.pad_token_id else: pad_token_id = self.processor.pad_token_id if pad_token_id is not None: labels[labels == pad_token_id] = self.ignore_index inputs["labels"] = labels return inputs ================================================ FILE: studio/backend/utils/datasets/dataset_utils.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Dataset utilities for format detection, conversion, and template application. This module provides the main entry points for dataset processing: - check_dataset_format: Lightweight check if manual mapping is needed (for frontend) - format_dataset: Detects and normalizes dataset formats - format_and_template_dataset: End-to-end processing with chat template application All internal utilities have been moved to separate modules: - format_detection: detect_dataset_format, detect_multimodal_dataset, etc. - format_conversion: standardize_chat_format, convert_chatml_to_alpaca, etc. - chat_templates: apply_chat_template_to_dataset, get_tokenizer_chat_template, etc. 
- vlm_processing: generate_smart_vlm_instruction - data_collators: DeepSeekOCRDataCollator, VLMDataCollator - model_mappings: TEMPLATE_TO_MODEL_MAPPER """ import json # Import from modular files from .format_detection import ( detect_dataset_format, detect_multimodal_dataset, detect_vlm_dataset_structure, detect_custom_format_heuristic, ) from .format_conversion import ( standardize_chat_format, convert_chatml_to_alpaca, convert_alpaca_to_chatml, convert_to_vlm_format, convert_llava_to_vlm_format, convert_sharegpt_with_images_to_vlm_format, ) from .chat_templates import ( apply_chat_template_to_dataset, get_dataset_info_summary, get_tokenizer_chat_template, DEFAULT_ALPACA_TEMPLATE, ) from .vlm_processing import generate_smart_vlm_instruction from .data_collators import DeepSeekOCRDataCollator, VLMDataCollator from .model_mappings import TEMPLATE_TO_MODEL_MAPPER from loggers import get_logger logger = get_logger(__name__) def check_dataset_format(dataset, is_vlm: bool = False) -> dict: """ Lightweight format check without processing - for frontend validation. Use this to quickly determine if user needs to manually map columns before calling the full format_and_template_dataset(). 
Args: dataset: HuggingFace dataset is_vlm: Whether this is a Vision-Language Model dataset Returns: dict: { "requires_manual_mapping": bool - True if user must map columns, "detected_format": str - The detected format, "columns": list - Available column names for mapping UI, "suggested_mapping": dict or None - Auto-detected mapping if available, "detected_image_column": str or None - For VLM only, "detected_text_column": str or None - For VLM only, } """ columns = ( list(dataset.column_names) if hasattr(dataset, "column_names") else list(next(iter(dataset)).keys()) ) # Auto-detect multimodal data regardless of is_vlm flag multimodal_info = detect_multimodal_dataset(dataset) is_audio = multimodal_info.get("is_audio", False) # Common audio fields for all return paths audio_fields = { "is_audio": is_audio, "detected_audio_column": multimodal_info.get("detected_audio_column"), "detected_speaker_column": multimodal_info.get("detected_speaker_column"), } if is_vlm: vlm_structure = detect_vlm_dataset_structure(dataset) requires_mapping = vlm_structure["format"] == "unknown" warning = None if requires_mapping: img_col = vlm_structure.get("image_column") txt_col = vlm_structure.get("text_column") missing = [] if not img_col: missing.append("image") if not txt_col: missing.append("text") if missing: warning = ( f"Could not auto-detect {' or '.join(missing)} column. " "Please assign image and text columns manually." 
) return { "requires_manual_mapping": requires_mapping, "detected_format": vlm_structure["format"], "columns": columns, "suggested_mapping": None, "detected_image_column": vlm_structure.get("image_column"), "detected_text_column": vlm_structure.get("text_column"), "is_image": multimodal_info["is_image"], "multimodal_columns": multimodal_info.get("multimodal_columns"), "warning": warning, **audio_fields, } if is_audio: # Audio dataset — require manual mapping only when columns can't be auto-detected detected_audio = multimodal_info.get("detected_audio_column") detected_text = multimodal_info.get("detected_text_column") needs_mapping = not detected_audio or not detected_text return { "requires_manual_mapping": needs_mapping, "detected_format": "audio", "columns": columns, "suggested_mapping": None, "detected_image_column": None, "detected_text_column": multimodal_info.get("detected_text_column"), "is_image": False, "multimodal_columns": multimodal_info.get("audio_columns"), **audio_fields, } # Text / LLM flow detected = detect_dataset_format(dataset) # If format is unknown, try heuristic detection if detected["format"] == "unknown": heuristic_mapping = detect_custom_format_heuristic(dataset) if heuristic_mapping: return { "requires_manual_mapping": False, "detected_format": "custom_heuristic", "columns": columns, "suggested_mapping": heuristic_mapping, "detected_image_column": None, "detected_text_column": None, "is_image": multimodal_info["is_image"], "multimodal_columns": multimodal_info.get("multimodal_columns"), **audio_fields, } else: # Heuristic failed — user must map manually (or use AI Assist) return { "requires_manual_mapping": True, "detected_format": "unknown", "columns": columns, "suggested_mapping": None, "detected_image_column": None, "detected_text_column": None, "is_image": multimodal_info["is_image"], "multimodal_columns": multimodal_info.get("multimodal_columns"), "warning": ( f"Could not auto-detect column roles for columns: {columns}. 
" "Please assign roles manually, or use AI Assist." ), **audio_fields, } # Known format detected return { "requires_manual_mapping": False, "detected_format": detected["format"], "columns": columns, "suggested_mapping": None, "detected_image_column": None, "detected_text_column": None, "is_image": multimodal_info["is_image"], "multimodal_columns": multimodal_info.get("multimodal_columns"), **audio_fields, } # Normalise any format-specific role to canonical chatml (user/assistant/system) _TO_CHATML = { "user": "user", "human": "user", "instruction": "user", "assistant": "assistant", "gpt": "assistant", "output": "assistant", "system": "system", "input": "system", } _CHATML_ROLE_ORDER = ("system", "user", "assistant") _CHATML_TO_ALPACA = {"user": "instruction", "system": "input", "assistant": "output"} def _apply_user_mapping(dataset, mapping: dict, batch_size: int = 1000): """ Apply user-provided column mapping to convert dataset to conversations format. Accepts chatml (user/assistant/system), sharegpt (human/gpt/system), and alpaca (instruction/input/output) role names — all normalised to chatml output. If the mapping contains ``__``-prefixed metadata keys (from the conversion advisor), routes to template-based conversion instead of simple role mapping. 
Returns: Dataset with single 'conversations' column """ # Split metadata from column roles meta = {k: v for k, v in mapping.items() if k.startswith("__")} column_roles = {k: v for k, v in mapping.items() if not k.startswith("__")} if meta: return _apply_template_mapping(dataset, column_roles, meta, batch_size) # ── Simple mode (original logic) ── # Pre-compute: group columns by canonical chatml role role_groups: dict[str, list[str]] = {r: [] for r in _CHATML_ROLE_ORDER} for col_name, role in column_roles.items(): canonical = _TO_CHATML.get(role) if canonical: role_groups[canonical].append(col_name) def _convert(examples): num = len(next(iter(examples.values()))) conversations = [] for i in range(num): convo = [] for chatml_role in _CHATML_ROLE_ORDER: for col in role_groups[chatml_role]: if col in examples: content = examples[col][i] convo.append( { "role": chatml_role, "content": str(content) if content else "", } ) conversations.append(convo) return {"conversations": conversations} return dataset.map( _convert, batched = True, batch_size = batch_size, remove_columns = dataset.column_names, ) def _extract_column_value(val, col: str, label_mapping: dict) -> str: """Extract a string value from a column, handling complex types and label mapping.""" # Handle complex types (dicts, lists) — extract useful text instead of raw repr if isinstance(val, dict): # Common pattern: {"text": [...]} in QA datasets if "text" in val: inner = val["text"] str_val = inner[0] if isinstance(inner, list) and inner else str(inner) else: str_val = json.dumps(val, ensure_ascii = False) elif isinstance(val, list): str_val = val[0] if len(val) == 1 else ", ".join(str(v) for v in val) else: str_val = str(val) if val is not None else "" # Apply label mapping if this column has one if col in label_mapping and isinstance(label_mapping[col], dict): str_val = label_mapping[col].get(str_val, str_val) return str_val def _apply_template_mapping( dataset, column_roles: dict, meta: dict, batch_size: int = 
1000 ): """ Apply advisor-driven mapping for non-conversational datasets. Groups columns by their assigned role (user/assistant), concatenates values within each role into a single message, and injects an optional system prompt. Label mapping is applied to convert integer labels to human-readable strings. Returns: Dataset with single 'conversations' column """ system_prompt = meta.get("__system_prompt", "") label_mapping = meta.get("__label_mapping", {}) # {col: {int_str: label_str}} # Group columns by canonical chatml role role_groups: dict[str, list[str]] = {"user": [], "assistant": []} for col, role in column_roles.items(): canonical = _TO_CHATML.get(role, role) if canonical in role_groups: role_groups[canonical].append(col) import logging as _log _log.getLogger(__name__).info( f"Applying role mapping: sys={bool(system_prompt)}, " f"user_cols={role_groups['user']}, asst_cols={role_groups['assistant']}, " f"label_map={list(label_mapping.keys())}" ) def _convert(examples): num = len(next(iter(examples.values()))) conversations = [] for i in range(num): convo = [] # System prompt (generated, static across all rows) if system_prompt: convo.append({"role": "system", "content": system_prompt}) # User message: concatenate all user-role column values user_parts = [] for col in role_groups["user"]: if col in examples: user_parts.append( _extract_column_value(examples[col][i], col, label_mapping) ) if user_parts: convo.append({"role": "user", "content": "\n".join(user_parts)}) # Assistant message: concatenate all assistant-role column values asst_parts = [] for col in role_groups["assistant"]: if col in examples: asst_parts.append( _extract_column_value(examples[col][i], col, label_mapping) ) if asst_parts: convo.append({"role": "assistant", "content": "\n".join(asst_parts)}) conversations.append(convo) return {"conversations": conversations} return dataset.map( _convert, batched = True, batch_size = batch_size, remove_columns = dataset.column_names, ) def 
_apply_user_mapping_alpaca(dataset, mapping: dict, batch_size: int = 1000): """ Apply user-provided column mapping to convert dataset to Alpaca format. Accepts any format's role names — normalises via _TO_CHATML, then maps user → instruction, system → input, assistant → output. Returns: Dataset with instruction/input/output columns """ col_for: dict[str, str | None] = { "instruction": None, "input": None, "output": None, } for col_name, role in mapping.items(): canonical = _TO_CHATML.get(role) alpaca_field = _CHATML_TO_ALPACA.get(canonical) if canonical else None if alpaca_field: col_for[alpaca_field] = col_name def _convert(examples): num = len(next(iter(examples.values()))) instructions, inputs, outputs = [], [], [] for i in range(num): for field, dest in ( ("instruction", instructions), ("input", inputs), ("output", outputs), ): col = col_for[field] val = ( str(examples[col][i]) if col and col in examples and examples[col][i] else "" ) dest.append(val) return {"instruction": instructions, "input": inputs, "output": outputs} return dataset.map( _convert, batched = True, batch_size = batch_size, remove_columns = dataset.column_names, ) def format_dataset( dataset, format_type = "auto", tokenizer = None, aliases_for_system = [ "system", ], aliases_for_user = [ "user", "human", "input", ], aliases_for_assistant = [ "gpt", "assistant", "output", ], batch_size = 1000, num_proc = None, auto_detect_custom = True, custom_format_mapping = None, ): """ Formats dataset and returns metadata. 
Returns: dict: { "dataset": processed dataset, "detected_format": original format detected, "final_format": final format after processing, "chat_column": column name with chat data, "is_standardized": whether role names are standardized, "requires_manual_mapping": True if format detection failed and user must map columns, "warnings": list of warning messages } """ # Detect multimodal first (needed for all flows) multimodal_info = detect_multimodal_dataset(dataset) # If user provided explicit mapping, skip detection and apply in the requested format if custom_format_mapping: try: if format_type == "alpaca": mapped_dataset = _apply_user_mapping_alpaca( dataset, custom_format_mapping, batch_size ) final_format = "alpaca" chat_column = None else: # auto / chatml / sharegpt / conversational — all produce chatml conversations # (sharegpt is always standardized to role/content internally) mapped_dataset = _apply_user_mapping( dataset, custom_format_mapping, batch_size ) final_format = "chatml_conversations" chat_column = "conversations" return { "dataset": mapped_dataset, "detected_format": "user_mapped", "final_format": final_format, "chat_column": chat_column, "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [ f"Applied user-provided column mapping ({format_type}): {custom_format_mapping}" ], } except Exception as e: return { "dataset": dataset, "detected_format": "user_mapped", "final_format": "unknown", "chat_column": None, "is_standardized": False, "requires_manual_mapping": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [f"Failed to apply user mapping: {e}"], } # Detect current format detected = detect_dataset_format(dataset) warnings = [] # Add multimodal warning if detected if multimodal_info["is_image"]: warnings.append( f"Multimodal dataset detected. 
Found columns: {multimodal_info['multimodal_columns']}" ) # AUTO MODE: Keep format but standardize if needed if format_type == "auto": # Alpaca - keep as is if detected["format"] == "alpaca": return { "dataset": dataset, "detected_format": "alpaca", "final_format": "alpaca", "chat_column": None, "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } # ShareGPT - needs standardization elif detected["format"] == "sharegpt": try: standardized = standardize_chat_format( dataset, tokenizer, aliases_for_system, aliases_for_user, aliases_for_assistant, batch_size, num_proc, ) return { "dataset": standardized, "detected_format": "sharegpt", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } except Exception as e: warnings.append(f"Failed to standardize ShareGPT format: {e}") return { "dataset": dataset, "detected_format": "sharegpt", "final_format": "sharegpt", "chat_column": detected["chat_column"], "is_standardized": False, "requires_manual_mapping": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } elif detected["format"] == "chatml" and detected["chat_column"] in [ "conversations", "messages", "texts", ]: return { "dataset": dataset, "detected_format": f"chatml_{detected['chat_column']}", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } # Unknown - try standardization, if fails pass as is else: warnings.append( f"Unknown format detected. 
Keys found: {detected['sample_keys']}" ) # NEW: Try heuristic detection if auto_detect_custom: custom_mapping = detect_custom_format_heuristic(dataset) if custom_mapping: warnings.append(f"Auto-detected column mapping: {custom_mapping}") def _apply_auto_mapping(examples): conversations = [] num_examples = len(examples[list(examples.keys())[0]]) # Preserve non-mapped columns all_columns = set(examples.keys()) mapped_columns = set(custom_mapping.keys()) preserved_columns = { col: examples[col] for col in all_columns - mapped_columns } for i in range(num_examples): convo = [] for target_role in ["system", "user", "assistant"]: for col_name, role in custom_mapping.items(): if role == target_role and col_name in examples: content = examples[col_name][i] if content and str(content).strip(): convo.append( {"role": role, "content": str(content)} ) conversations.append(convo) return {"conversations": conversations, **preserved_columns} try: dataset = dataset.map( _apply_auto_mapping, batched = True, batch_size = batch_size ) return { "dataset": dataset, "detected_format": "unknown", "final_format": "chatml_conversations", "chat_column": "conversations", "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } except Exception as e: warnings.append(f"Auto-detection failed: {e}") # Try standardization as a last resort if detected["chat_column"]: try: standardized = standardize_chat_format( dataset, tokenizer, aliases_for_system, aliases_for_user, aliases_for_assistant, batch_size, num_proc, ) warnings.append("Successfully standardized unknown format") return { "dataset": standardized, "detected_format": "unknown", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } except 
Exception as e: warnings.append( f"Could not standardize: {e}. Passing dataset as-is." ) # Return as-is with warnings return { "dataset": dataset, "detected_format": "unknown", "final_format": "unknown", "chat_column": detected["chat_column"], "is_standardized": False, "requires_manual_mapping": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } # ALPACA MODE: Convert to Alpaca elif format_type == "alpaca": if detected["format"] == "alpaca": return { "dataset": dataset, "detected_format": "alpaca", "final_format": "alpaca", "chat_column": None, "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } elif detected["format"] in ["sharegpt", "chatml"]: # First standardize if ShareGPT if detected["format"] == "sharegpt": dataset = standardize_chat_format( dataset, tokenizer, aliases_for_system, aliases_for_user, aliases_for_assistant, batch_size, num_proc, ) # Then convert to Alpaca converted = convert_chatml_to_alpaca(dataset, batch_size, num_proc) return { "dataset": converted, "detected_format": detected["format"], "final_format": "alpaca", "chat_column": None, "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } else: warnings.append(f"Cannot convert unknown format to Alpaca") return { "dataset": dataset, "detected_format": "unknown", "final_format": "unknown", "chat_column": detected["chat_column"], "is_standardized": False, "requires_manual_mapping": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } # CHATML MODE: Convert to ChatML elif format_type in ["chatml", "conversational", "sharegpt"]: if detected["format"] == "alpaca": converted = convert_alpaca_to_chatml(dataset, batch_size, num_proc) return { "dataset": converted, "detected_format": "alpaca", 
"final_format": "chatml_conversations", "chat_column": "conversations", "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } elif detected["format"] == "sharegpt": standardized = standardize_chat_format( dataset, tokenizer, aliases_for_system, aliases_for_user, aliases_for_assistant, batch_size, num_proc, ) return { "dataset": standardized, "detected_format": "sharegpt", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } elif detected["format"] == "chatml": return { "dataset": dataset, "detected_format": f"chatml_{detected['chat_column']}", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": [], } else: warnings.append(f"Unknown format, attempting standardization") if detected["chat_column"]: try: standardized = standardize_chat_format( dataset, tokenizer, aliases_for_system, aliases_for_user, aliases_for_assistant, batch_size, num_proc, ) return { "dataset": standardized, "detected_format": "unknown", "final_format": f"chatml_{detected['chat_column']}", "chat_column": detected["chat_column"], "is_standardized": True, "requires_manual_mapping": False, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": warnings, } except Exception as e: warnings.append(f"Standardization failed: {e}") return { "dataset": dataset, "detected_format": "unknown", "final_format": "unknown", "chat_column": detected["chat_column"], "is_standardized": False, "requires_manual_mapping": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "warnings": 
warnings, } else: raise ValueError(f"Unknown format_type: {format_type}") def format_and_template_dataset( dataset, model_name, tokenizer, is_vlm = False, format_type = "auto", # VLM-specific parameters vlm_instruction = None, # Now optional - will auto-generate vlm_text_column = None, vlm_image_column = None, dataset_name = None, custom_prompt_template = None, add_eos_token = False, remove_bos_prefix = False, custom_format_mapping = None, auto_detect_custom = True, auto_detect_mapping = True, aliases_for_system = [ "system", ], aliases_for_user = [ "user", "human", "input", ], aliases_for_assistant = [ "gpt", "assistant", "output", ], batch_size = 1000, num_proc = None, progress_callback = None, ): """ Convenience function that combines format_dataset and apply_chat_template_to_dataset. Perfect for UI workflows - one function does everything! Returns: dict: { "dataset": Final dataset with 'text' column, "detected_format": Original format, "final_format": Format after processing, "success": Whether template application succeeded, "requires_manual_mapping": True if format detection failed and user must map columns, "warnings": List of warnings, "errors": List of errors, "summary": Human-readable summary } """ # VLM FLOW if is_vlm: warnings = [] errors = [] multimodal_info = detect_multimodal_dataset(dataset) # NEW: If user provided explicit mapping for VLM, use it directly if custom_format_mapping: # Expect mapping like: {"image_col": "image", "caption_col": "text"} user_vlm_image_column = None user_vlm_text_column = None for col, role in custom_format_mapping.items(): if role == "image": user_vlm_image_column = col elif role in ["text", "user", "caption", "assistant"]: user_vlm_text_column = col if user_vlm_image_column and user_vlm_text_column: try: dataset = convert_to_vlm_format( dataset, instruction = vlm_instruction, text_column = user_vlm_text_column, image_column = user_vlm_image_column, dataset_name = dataset_name, progress_callback = progress_callback, ) 
warnings.append( f"Applied user VLM mapping: image='{user_vlm_image_column}', text='{user_vlm_text_column}'" ) return { "dataset": dataset, "detected_format": "user_mapped", "final_format": "vlm_messages", "chat_column": "messages", "is_vlm": True, "is_image": True, "multimodal_info": multimodal_info, "success": True, "requires_manual_mapping": False, "warnings": warnings, "errors": [], } except Exception as e: # User mapping failed — fall back to auto-detection instead # of giving up (handles stale cached mappings gracefully) warnings.append( f"User VLM mapping (image='{user_vlm_image_column}', " f"text='{user_vlm_text_column}') failed: {e} — " f"falling back to auto-detection" ) logger.info( f"⚠️ User VLM mapping failed, falling back to auto-detection..." ) custom_format_mapping = None # clear so auto-detection runs below else: errors.append( f"Invalid VLM mapping: need 'image' and 'text' roles. Got: {custom_format_mapping}" ) return { "dataset": dataset, "detected_format": "user_mapped", "final_format": "vlm_unknown", "is_vlm": True, "success": False, "requires_manual_mapping": True, "warnings": warnings, "errors": errors, } # Auto-detect VLM structure vlm_structure = detect_vlm_dataset_structure(dataset) # Handle Llava format if vlm_structure["format"] == "vlm_messages_llava": try: dataset = convert_llava_to_vlm_format(dataset) warnings.append( "Converted from Llava format (image indices) to standard VLM format" ) except Exception as e: errors.append(f"Failed to convert Llava format: {e}") import traceback traceback.print_exc() return { "dataset": dataset, "detected_format": "vlm_messages_llava", "final_format": "vlm_conversion_failed", "is_vlm": True, "success": False, "requires_manual_mapping": True, "warnings": warnings, "errors": errors, } # Handle ShareGPT/ChatML + image column (e.g. 
ShareGPT4V, LLaVA-style) elif vlm_structure["format"] == "sharegpt_with_images": try: dataset = convert_sharegpt_with_images_to_vlm_format( dataset, image_column = vlm_structure["image_column"], messages_column = vlm_structure["messages_column"], dataset_name = dataset_name, progress_callback = progress_callback, ) warnings.append( "Converted from ShareGPT+image format to standard VLM format" ) except Exception as e: errors.append(f"Failed to convert ShareGPT+image format: {e}") import traceback traceback.print_exc() return { "dataset": dataset, "detected_format": "sharegpt_with_images", "final_format": "vlm_conversion_failed", "is_vlm": True, "success": False, "requires_manual_mapping": True, "warnings": warnings, "errors": errors, } # Handle simple format elif vlm_structure["needs_conversion"]: if vlm_text_column is None: vlm_text_column = vlm_structure["text_column"] if vlm_image_column is None: vlm_image_column = vlm_structure["image_column"] if vlm_text_column is None or vlm_image_column is None: columns = list(next(iter(dataset)).keys()) if dataset else [] issues = [ f"Could not auto-detect image and text columns from: {columns}", f"VLM structure detected: {vlm_structure.get('format', 'unknown')}", ] friendly = None try: from .llm_assist import llm_generate_dataset_warning friendly = llm_generate_dataset_warning( issues, dataset_name = dataset_name, modality = "vision", column_names = columns, ) except Exception: pass errors.append( friendly or f"Could not auto-detect image/text columns. Found: {vlm_structure}. 
" ) return { "dataset": dataset, "detected_format": "vlm_unknown", "final_format": "vlm_unknown", "is_vlm": True, "success": False, "requires_manual_mapping": True, "warnings": warnings, "errors": errors, } try: dataset = convert_to_vlm_format( dataset, instruction = vlm_instruction, text_column = vlm_text_column, image_column = vlm_image_column, dataset_name = dataset_name, progress_callback = progress_callback, ) if vlm_instruction: warnings.append( f"Using user-provided instruction: '{vlm_instruction}'" ) else: warnings.append( "Auto-generated instruction based on dataset analysis" ) except Exception as e: errors.append(f"Failed to convert to VLM format: {e}") import traceback traceback.print_exc() return { "dataset": dataset, "detected_format": vlm_structure["format"], "final_format": "vlm_conversion_failed", "is_vlm": True, "success": False, "requires_manual_mapping": True, "warnings": warnings, "errors": errors, } # Already in standard VLM format elif vlm_structure["format"] == "vlm_messages": dataset = [sample for sample in dataset] warnings.append("Dataset already in standard VLM messages format") # Return as list return { "dataset": dataset, "detected_format": vlm_structure["format"], "final_format": "vlm_messages", "chat_column": "messages", "is_vlm": True, "is_image": multimodal_info["is_image"], "multimodal_info": multimodal_info, "vlm_structure": vlm_structure, "success": True, "requires_manual_mapping": False, "warnings": warnings, "errors": errors, } # LLM FLOW (Existing code) else: # Step 1: Format the dataset n_rows = len(dataset) if hasattr(dataset, "__len__") else None if progress_callback and n_rows: progress_callback(status_message = f"Formatting dataset ({n_rows:,} rows)...") dataset_info = format_dataset( dataset, format_type = format_type, tokenizer = tokenizer, auto_detect_custom = auto_detect_custom, custom_format_mapping = custom_format_mapping, aliases_for_system = aliases_for_system, aliases_for_user = aliases_for_user, 
aliases_for_assistant = aliases_for_assistant, batch_size = batch_size, num_proc = num_proc, ) # Step 2: Apply chat template detected = dataset_info.get("detected_format", "unknown") if progress_callback and n_rows: progress_callback( status_message = f"Applying chat template to {detected} ({n_rows:,} rows)..." ) # Gemma emits a leading that must be stripped for text-only chatml/sharegpt. is_alpaca = format_type == "alpaca" or ( format_type == "auto" and dataset_info["detected_format"] == "alpaca" ) is_gemma = "gemma" in model_name.lower() if is_gemma and not dataset_info["is_image"] and not is_alpaca: remove_bos_prefix = True template_result = apply_chat_template_to_dataset( dataset_info = dataset_info, tokenizer = tokenizer, model_name = model_name, custom_prompt_template = custom_prompt_template, add_eos_token = add_eos_token, remove_bos_prefix = remove_bos_prefix, custom_format_mapping = custom_format_mapping, auto_detect_mapping = auto_detect_mapping, batch_size = batch_size, num_proc = num_proc, progress_callback = progress_callback, ) # Step 3: Generate summary summary = get_dataset_info_summary(dataset_info) # Combine results all_warnings = dataset_info.get("warnings", []) + template_result.get( "warnings", [] ) all_errors = template_result.get("errors", []) # If format_dataset returned "unknown" but apply_chat_template rescued # it via heuristic detection, update final_format to reflect reality. 
final_format = dataset_info["final_format"] requires_manual = dataset_info.get("requires_manual_mapping", False) if final_format == "unknown" and template_result["success"]: out_ds = template_result["dataset"] if hasattr(out_ds, "column_names") and "text" in out_ds.column_names: final_format = "chatml_conversations" requires_manual = False return { "dataset": template_result["dataset"], "detected_format": dataset_info["detected_format"], "final_format": final_format, "chat_column": dataset_info.get("chat_column"), "is_vlm": False, # This is LLM flow "success": template_result["success"], "requires_manual_mapping": requires_manual, "warnings": all_warnings, "errors": all_errors, "summary": summary, } ================================================ FILE: studio/backend/utils/datasets/format_conversion.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Format conversion utilities for dataset processing. This module contains functions for converting between dataset formats (Alpaca, ShareGPT, ChatML) and standardizing chat formats. """ import os from datasets import IterableDataset from loggers import get_logger logger = get_logger(__name__) def standardize_chat_format( dataset, tokenizer = None, aliases_for_system = [ "system", ], aliases_for_user = [ "user", "human", "input", ], aliases_for_assistant = [ "gpt", "assistant", "output", ], batch_size = 1000, num_proc = None, ): """ Our own standardization function that handles BOTH messages and conversations. Converts non-standard role names and keys to standard format. 
""" import collections import itertools from datasets import IterableDataset # Check if vision tokenizer is used is_vlm = False if tokenizer is not None: if hasattr(tokenizer, "image_processor") or hasattr(tokenizer, "tokenizer"): is_vlm = True column_names = set(next(iter(dataset)).keys()) # Check for both 'conversations' and 'messages' chat_column = None if "conversations" in column_names: chat_column = "conversations" elif "messages" in column_names: chat_column = "messages" elif "texts" in column_names: chat_column = "texts" else: return dataset # No chat column found # Inspect structure examples = itertools.islice(dataset, 10) uniques = collections.defaultdict(list) for example in examples: for message in example[chat_column]: for key, value in message.items(): if type(value) is not str: continue # Skip non-string values uniques[key].append(value) if len(uniques.keys()) != 2: return dataset # Unexpected structure keys = list(uniques.keys()) length_first = len(set(uniques[keys[0]])) length_second = len(set(uniques[keys[1]])) # Determine which is role and which is content if length_first < length_second: role_key = keys[0] content_key = keys[1] else: role_key = keys[1] content_key = keys[0] # Mapping for aliases aliases_mapping = {} for x in aliases_for_system: aliases_mapping[x] = "system" for x in aliases_for_user: aliases_mapping[x] = "user" for x in aliases_for_assistant: aliases_mapping[x] = "assistant" def _standardize_dataset(examples): convos = examples[chat_column] all_convos = [] for convo in convos: new_convo = [] for message in convo: # Get original role and content original_role = message.get(role_key, "") original_content = message.get(content_key, "") # Map to standard role name standard_role = aliases_mapping.get(original_role, original_role) # Handle VLM format if is_vlm: original_content = [{"type": "text", "text": original_content}] # Create dict with EXPLICIT ORDER new_message = {"role": standard_role, "content": original_content} 
new_convo.append(new_message) all_convos.append(new_convo) return {chat_column: all_convos} dataset_map_kwargs = { "batched": True, "batch_size": batch_size, } if not isinstance(dataset, IterableDataset): from utils.hardware import safe_num_proc if num_proc is None or type(num_proc) is not int: num_proc = safe_num_proc() else: num_proc = safe_num_proc(num_proc) dataset_map_kwargs["num_proc"] = num_proc dataset_map_kwargs["desc"] = "Standardizing chat format" return dataset.map(_standardize_dataset, **dataset_map_kwargs) def convert_chatml_to_alpaca(dataset, batch_size = 1000, num_proc = None): """ Converts ChatML format (messages OR conversations) to Alpaca format. Handles both standardized and ShareGPT formats. Supports: - "messages" or "conversations" column - "role"/"content" (standard) or "from"/"value" (ShareGPT) """ from torch.utils.data import IterableDataset def _convert(examples): # Auto-detect which column name is used chatml_data = ( examples.get("messages") or examples.get("conversations") or examples.get("texts") ) if chatml_data is None: raise ValueError( "No 'messages' or 'conversations' or 'texts' column found." 
) instructions = [] outputs = [] inputs = [] for convo in chatml_data: instruction = "" output = "" for msg in convo: # Handle both standard and ShareGPT formats role = msg.get("role") or msg.get("from") content = msg.get("content") or msg.get("value") # Get first user message as instruction if role in ["user", "human", "input"] and not instruction: instruction = content # Get first assistant message as output elif role in ["assistant", "gpt", "output"] and not output: output = content break # Stop after first assistant response instructions.append(instruction) inputs.append("") # Alpaca typically has empty input outputs.append(output) return {"instruction": instructions, "input": inputs, "output": outputs} dataset_map_kwargs = { "batched": True, "batch_size": batch_size, } if not isinstance(dataset, IterableDataset): from utils.hardware import safe_num_proc if num_proc is None or type(num_proc) is not int: num_proc = safe_num_proc() else: num_proc = safe_num_proc(num_proc) dataset_map_kwargs["num_proc"] = num_proc dataset_map_kwargs["desc"] = "Converting ChatML to Alpaca format" return dataset.map(_convert, **dataset_map_kwargs) def convert_alpaca_to_chatml(dataset, batch_size = 1000, num_proc = None): """ Converts Alpaca format to ChatML format. Output format: Uses 'conversations' column with standard 'role'/'content' structure. 
""" from torch.utils.data import IterableDataset def _convert(examples): conversations = [] for i in range(len(examples["instruction"])): instruction = examples["instruction"][i] input_text = examples.get("input", [""] * len(examples["instruction"]))[i] output = examples["output"][i] # Combine instruction and input (if exists) for user message if input_text and input_text.strip(): user_content = f"{instruction}\n\n{input_text}".strip() else: user_content = instruction # Build conversation in standard ChatML format convo = [ {"role": "user", "content": user_content}, {"role": "assistant", "content": output}, ] conversations.append(convo) return {"conversations": conversations} dataset_map_kwargs = { "batched": True, "batch_size": batch_size, } if not isinstance(dataset, IterableDataset): from utils.hardware import safe_num_proc if num_proc is None or type(num_proc) is not int: num_proc = safe_num_proc() else: num_proc = safe_num_proc(num_proc) dataset_map_kwargs["num_proc"] = num_proc dataset_map_kwargs["desc"] = "Converting Alpaca to ChatML format" return dataset.map(_convert, **dataset_map_kwargs) def _format_eta(seconds): """Format seconds into a human-readable ETA string.""" if seconds < 60: return f"{seconds:.0f}s" elif seconds < 3600: m, s = divmod(int(seconds), 60) return f"{m}m {s}s" else: h, remainder = divmod(int(seconds), 3600) m, _ = divmod(remainder, 60) return f"{h}h {m}m" def convert_to_vlm_format( dataset, instruction = None, text_column = "text", image_column = "image", dataset_name = None, progress_callback = None, ): """ Converts simple {image, text} format to VLM messages format. Returns a LIST, not a HuggingFace Dataset (to preserve PIL Images). For URL-based image datasets, runs a 200-sample parallel probe first to estimate download speed and failure rate, then reports time estimate or warning through progress_callback before proceeding with the full conversion. 
Args: progress_callback: Optional callable(status_message=str) to report progress to the training overlay. Returns: list: List of dicts with 'messages' field """ from PIL import Image from .vlm_processing import generate_smart_vlm_instruction def _notify(msg): """Send status update to the training overlay if callback is available.""" if progress_callback: progress_callback(status_message = msg) # Generate smart instruction if not provided if instruction is None: instruction_info = generate_smart_vlm_instruction( dataset, text_column = text_column, image_column = image_column, dataset_name = dataset_name, ) instruction = instruction_info["instruction"] instruction_column = instruction_info.get("instruction_column") uses_dynamic = instruction_info["uses_dynamic_instruction"] logger.info( f"📝 Auto-detected instruction type: {instruction_info['instruction_type']}" ) logger.info(f"📝 Confidence: {instruction_info['confidence']:.2f}") if not uses_dynamic: logger.info(f"📝 Using instruction: '{instruction}'") else: logger.info( f"📝 Using dynamic instructions from column: '{instruction_column}'" ) else: instruction_column = None uses_dynamic = False def _convert_single_sample(sample): """Convert a single sample to VLM format.""" # Get image (might be PIL Image, local path, URL, or bare filename) image_data = sample[image_column] if isinstance(image_data, str): if image_data.startswith(("http://", "https://")): import fsspec from io import BytesIO with fsspec.open(image_data, "rb", expand = True) as f: image_data = Image.open(BytesIO(f.read())).convert("RGB") elif _image_lookup is not None and image_data in _image_lookup: # Bare filename → resolve via HF repo lookup from huggingface_hub import hf_hub_download local_path = hf_hub_download( dataset_name, _image_lookup[image_data], repo_type = "dataset", ) image_data = Image.open(local_path).convert("RGB") else: image_data = Image.open(image_data).convert("RGB") # Get text (if list of strings, pick a random one — e.g. 
multiple captions) text_data = sample[text_column] if isinstance(text_data, list) and len(text_data) > 0: import random text_data = random.choice(text_data) # Get instruction (static or dynamic) if uses_dynamic and instruction_column: current_instruction = sample[instruction_column] else: current_instruction = instruction # Build VLM messages - simple structure messages = [ { "role": "user", "content": [ {"type": "text", "text": current_instruction}, {"type": "image", "image": image_data}, # PIL object ], }, {"role": "assistant", "content": [{"type": "text", "text": text_data}]}, ] # Return dict with messages return {"messages": messages} total = len(dataset) first_image = next(iter(dataset))[image_column] has_urls = isinstance(first_image, str) and first_image.startswith( ("http://", "https://") ) # ── Bare-filename detection: images stored as filenames (e.g. "img_001.png") # that don't exist locally. Build a basename→repo_path lookup so we can # resolve them via hf_hub_download during conversion. _image_lookup = None _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff") if ( not has_urls and isinstance(first_image, str) and not os.path.exists(first_image) and dataset_name ): try: from huggingface_hub import HfApi _notify("Resolving image filenames from HF repo...") logger.info( f"🔍 Image column contains bare filenames (e.g. '{first_image}') — building repo lookup..." ) repo_files = HfApi().list_repo_files(dataset_name, repo_type = "dataset") _image_lookup = { os.path.basename(f): f for f in repo_files if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS) } if first_image in _image_lookup: logger.info( f"✅ Matched {len(_image_lookup)} image files in repo (e.g. 
'{first_image}' → '{_image_lookup[first_image]}')" ) else: logger.info( f"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open" ) _image_lookup = None except Exception as e: logger.info(f"⚠️ Failed to build HF repo image lookup: {e}") _image_lookup = None # ── URL probe: 200 samples with parallel workers to estimate speed + failure rate ── PROBE_SIZE = 200 MAX_FAIL_RATE = 0.3 if has_urls and total > PROBE_SIZE: import time from concurrent.futures import ThreadPoolExecutor, as_completed from utils.hardware import safe_num_proc num_workers = safe_num_proc() _notify(f"Probing {PROBE_SIZE} image URLs with {num_workers} workers...") logger.info( f"🔍 Probing {PROBE_SIZE}/{total} image URLs with {num_workers} workers..." ) probe_samples = [dataset[i] for i in range(PROBE_SIZE)] probe_ok = 0 probe_fail = 0 probe_start = time.time() with ThreadPoolExecutor(max_workers = num_workers) as executor: futures = { executor.submit(_convert_single_sample, s): s for s in probe_samples } for future in as_completed(futures): try: future.result() probe_ok += 1 except Exception: probe_fail += 1 probe_elapsed = time.time() - probe_start probe_total = probe_ok + probe_fail fail_rate = probe_fail / probe_total if probe_total > 0 else 0 throughput = probe_total / probe_elapsed if probe_elapsed > 0 else 0 if fail_rate >= MAX_FAIL_RATE: issues = [ f"{fail_rate:.0%} of the first {PROBE_SIZE} image URLs failed to download ({probe_fail}/{probe_total})", "Images are external URLs, not embedded in the dataset", ] # Try LLM-friendly warning friendly = None try: from .llm_assist import llm_generate_dataset_warning friendly = llm_generate_dataset_warning( issues, dataset_name = dataset_name, modality = "vision", column_names = [image_column, text_column], ) except Exception: pass msg = friendly or ( f"⚠️ {fail_rate:.0%} of the first {PROBE_SIZE} images failed to download " f"({probe_fail}/{probe_total}). 
" "This dataset has too many broken or unreachable image URLs. " "Consider using a dataset with embedded images instead." ) logger.info(msg) _notify(msg) raise ValueError(msg) # Estimate total time for remaining samples remaining = total - PROBE_SIZE estimated_seconds = remaining / throughput if throughput > 0 else 0 eta_str = _format_eta(estimated_seconds) info_msg = ( f"Downloading {total:,} images ({num_workers} workers, ~{throughput:.1f} img/s). " f"Estimated time: ~{eta_str}" ) if probe_fail > 0: info_msg += f" | {fail_rate:.0%} broken URLs will be skipped" logger.info( f"✅ Probe passed: {probe_ok}/{probe_total} ok, {probe_fail} failed ({fail_rate:.0%}), {throughput:.1f} img/s" ) logger.info(f"⏱️ Estimated time for {total:,} samples: ~{eta_str}") _notify(info_msg) # ── Full conversion with progress ── from tqdm import tqdm logger.info(f"🔄 Converting {total} samples to VLM format...") converted_list = [] failed_count = 0 if has_urls: # Parallel conversion for URL-based datasets import time from concurrent.futures import ThreadPoolExecutor, as_completed from utils.hardware import safe_num_proc num_workers = safe_num_proc() batch_size = 500 start_time = time.time() for batch_start in range(0, total, batch_size): batch_end = min(batch_start + batch_size, total) batch_samples = [dataset[i] for i in range(batch_start, batch_end)] with ThreadPoolExecutor(max_workers = num_workers) as executor: futures = { executor.submit(_convert_single_sample, s): i for i, s in enumerate(batch_samples) } batch_results = [None] * len(batch_samples) for future in as_completed(futures): idx = futures[future] try: batch_results[idx] = future.result() except Exception as e: failed_count += 1 if failed_count == 1: print( f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}" ) if failed_count == 1: logger.info( f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}" ) converted_list.extend(r for r in batch_results if r is not None) # Progress update every batch elapsed = 
time.time() - start_time done = batch_end rate = done / elapsed if elapsed > 0 else 0 remaining_time = (total - done) / rate if rate > 0 else 0 eta_str = _format_eta(remaining_time) progress_msg = f"Downloading images: {done:,}/{total:,} ({done*100//total}%) | ~{eta_str} remaining | {failed_count} skipped" logger.info( f" [{done}/{total}] {rate:.1f} img/s, {failed_count} failed, ETA {eta_str}" ) _notify(progress_msg) else: # Sequential conversion for local/embedded images (fast, no I/O bottleneck) pbar = tqdm(dataset, total = total, desc = "Converting VLM samples", unit = "sample") for sample in pbar: try: converted_list.append(_convert_single_sample(sample)) except Exception as e: failed_count += 1 if failed_count == 1: # Log the first failure to aid debugging print(f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}") if failed_count == 1: # Log the first failure to aid debugging logger.info( f"⚠️ First VLM conversion failure: {type(e).__name__}: {e}" ) pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False) pbar.close() if failed_count > 0: fail_rate = failed_count / total logger.info( f"⚠️ Skipped {failed_count}/{total} ({fail_rate:.0%}) samples with broken/unreachable images" ) # For datasets that skipped the probe (small URL datasets), check fail rate now if has_urls and fail_rate >= MAX_FAIL_RATE: issues = [ f"{fail_rate:.0%} of images failed to download ({failed_count}/{total})", "Images are external URLs, not embedded in the dataset", ] friendly = None try: from .llm_assist import llm_generate_dataset_warning friendly = llm_generate_dataset_warning( issues, dataset_name = dataset_name, modality = "vision", column_names = [image_column, text_column], ) except Exception: pass msg = friendly or ( f"⚠️ {fail_rate:.0%} of images failed to download ({failed_count}/{total}). " "This dataset has too many broken or unreachable image URLs. " "Consider using a dataset with embedded images instead." 
) _notify(msg) raise ValueError(msg) if len(converted_list) == 0: issues = [ f"All {total} samples failed during VLM conversion — no usable images found", f"Image column '{image_column}' may contain URLs that are no longer accessible, " "or local file paths that don't exist", ] friendly = None try: from .llm_assist import llm_generate_dataset_warning friendly = llm_generate_dataset_warning( issues, dataset_name = dataset_name, modality = "vision", column_names = [image_column, text_column], ) except Exception: pass raise ValueError( friendly or ( f"All {total} samples failed during VLM conversion — no usable images found. " "This dataset may contain only image URLs that are no longer accessible." ) ) logger.info(f"✅ Converted {len(converted_list)}/{total} samples") _notify(f"Converted {len(converted_list):,}/{total:,} images successfully") # Return list, NOT Dataset return converted_list def convert_sharegpt_with_images_to_vlm_format( dataset, image_column = "image", messages_column = "conversations", dataset_name = None, progress_callback = None, ): """ Converts ShareGPT/ChatML datasets that have a separate image column and ```` placeholders inside the conversation text. Example input:: { "image": "sam/images/sa_545504.jpg", "conversations": [ {"from": "human", "value": "\\nWhat is this photo about?"}, {"from": "gpt", "value": "The image captures..."} ] } Returns a list of dicts in standard VLM messages format (PIL Images inline). 
""" from PIL import Image from tqdm import tqdm _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff") _ROLE_MAP = { "human": "user", "user": "user", "gpt": "assistant", "assistant": "assistant", "system": "system", } def _notify(msg): if progress_callback: progress_callback(status_message = msg) # ── Resolve image loading strategy (same 3-tier as convert_to_vlm_format) ── total = len(dataset) first_image = next(iter(dataset))[image_column] _image_lookup = None if ( isinstance(first_image, str) and not first_image.startswith(("http://", "https://")) and not os.path.exists(first_image) and dataset_name ): try: from huggingface_hub import HfApi _notify("Resolving image filenames from HF repo...") logger.info( f"🔍 Image column contains bare filenames (e.g. '{first_image}') — building repo lookup..." ) repo_files = HfApi().list_repo_files(dataset_name, repo_type = "dataset") _image_lookup = { os.path.basename(f): f for f in repo_files if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS) } # Also add the full relative paths as keys (for paths like "sam/images/sa_545504.jpg") for f in repo_files: if any(f.lower().endswith(ext) for ext in _IMAGE_EXTS): _image_lookup[f] = f if first_image in _image_lookup: logger.info( f"✅ Matched {len(_image_lookup)} image files in repo (e.g. 
'{first_image}' → '{_image_lookup[first_image]}')" ) else: logger.info( f"⚠️ Built lookup with {len(_image_lookup)} images but '{first_image}' not found — falling back to local open" ) _image_lookup = None except Exception as e: logger.info(f"⚠️ Failed to build HF repo image lookup: {e}") _image_lookup = None def _resolve_image(image_data): """Resolve image data to a PIL Image object.""" if hasattr(image_data, "size") and hasattr(image_data, "mode"): return image_data # Already PIL if isinstance(image_data, str): if image_data.startswith(("http://", "https://")): import fsspec from io import BytesIO with fsspec.open(image_data, "rb", expand = True) as f: return Image.open(BytesIO(f.read())).convert("RGB") elif _image_lookup is not None and image_data in _image_lookup: from huggingface_hub import hf_hub_download local_path = hf_hub_download( dataset_name, _image_lookup[image_data], repo_type = "dataset", ) return Image.open(local_path).convert("RGB") else: return Image.open(image_data).convert("RGB") if isinstance(image_data, dict) and ( "bytes" in image_data or "path" in image_data ): if image_data.get("bytes"): from io import BytesIO return Image.open(BytesIO(image_data["bytes"])).convert("RGB") if image_data.get("path"): return Image.open(image_data["path"]).convert("RGB") raise ValueError(f"Cannot resolve image: {type(image_data)}") def _convert_single_sample(sample): """Convert a single ShareGPT+image sample to standard VLM format.""" pil_image = _resolve_image(sample[image_column]) conversation = sample[messages_column] new_messages = [] for msg in conversation: role_raw = msg.get("from") or msg.get("role", "user") role = _ROLE_MAP.get(role_raw.lower(), role_raw.lower()) text = msg.get("value") or msg.get("content") or "" # Split on to interleave text and image content blocks if "" in text: parts = text.split("") content = [] for i, part in enumerate(parts): part = part.strip() if part: content.append({"type": "text", "text": part}) if i < len(parts) - 1: 
content.append({"type": "image", "image": pil_image}) # If was the entire text, content might just be the image if not content: content.append({"type": "image", "image": pil_image}) else: content = [{"type": "text", "text": text}] new_messages.append({"role": role, "content": content}) return {"messages": new_messages} # ── Full conversion with progress ── logger.info(f"🔄 Converting {total} samples from ShareGPT+image format...") converted_list = [] failed_count = 0 pbar = tqdm(dataset, total = total, desc = "Converting ShareGPT+image", unit = "sample") for sample in pbar: try: converted_list.append(_convert_single_sample(sample)) except Exception as e: failed_count += 1 if failed_count == 1: logger.info(f"⚠️ First conversion failure: {type(e).__name__}: {e}") pbar.set_postfix(ok = len(converted_list), failed = failed_count, refresh = False) pbar.close() if failed_count > 0: logger.info( f"⚠️ Skipped {failed_count}/{total} ({failed_count*100//total}%) samples" ) if len(converted_list) == 0: raise ValueError( f"All {total} samples failed during ShareGPT+image conversion — " "no usable samples found." ) logger.info(f"✅ Converted {len(converted_list)}/{total} samples") _notify(f"Converted {len(converted_list):,}/{total:,} samples successfully") return converted_list def convert_llava_to_vlm_format(dataset): """ Converts Llava format to standard VLM format. Llava format: - messages: [{'content': [{'type': 'image', 'index': 0}, {'type': 'text', 'text': '...'}]}] - images: [PIL_Image1, PIL_Image2, ...] Standard VLM format: - messages: [{'content': [{'type': 'image', 'image': PIL_Image}, {'type': 'text', 'text': '...'}]}] """ from PIL import Image logger.info( f"🔄 Converting {len(dataset)} samples from Llava format to standard VLM format..." 
) def _convert_single_sample(sample): """Convert a single llava sample to standard VLM format.""" messages = sample["messages"] images = sample.get("images", []) # Process each message new_messages = [] for msg in messages: new_content = [] for item in msg["content"]: if item["type"] == "image": # Replace index with actual PIL image if "index" in item and item["index"] is not None: img_idx = item["index"] if img_idx < len(images): pil_image = images[img_idx] # Ensure it's PIL if isinstance(pil_image, str): pil_image = Image.open(pil_image).convert("RGB") new_content.append( { "type": "image", "image": pil_image, # Actual PIL object } ) else: # No index, try to use first image if len(images) > 0: pil_image = images[0] if isinstance(pil_image, str): pil_image = Image.open(pil_image).convert("RGB") new_content.append({"type": "image", "image": pil_image}) elif item["type"] == "text": # Keep text as-is (only type + text) new_content.append({"type": "text", "text": item.get("text", "")}) new_messages.append({"role": msg["role"], "content": new_content}) return {"messages": new_messages} # Convert using list comprehension converted_list = [_convert_single_sample(sample) for sample in dataset] logger.info(f"✅ Converted {len(converted_list)} samples") return converted_list ================================================ FILE: studio/backend/utils/datasets/format_detection.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Format detection utilities for dataset processing. This module contains functions for detecting dataset formats (Alpaca, ShareGPT, ChatML), detecting multimodal/VLM dataset structures, and heuristic-based column mapping. 
""" import re def _keyword_in_column(keyword: str, col_name: str) -> bool: """Word-boundary keyword match to avoid false positives like 'pic' in 'topic'.""" return ( re.search(r"\b" + re.escape(keyword) + r"\b", col_name, re.IGNORECASE) is not None ) def detect_dataset_format(dataset): """ Detects dataset format by inspecting structure. Returns: dict: { "format": "alpaca" | "sharegpt" | "chatml" | "unknown", "chat_column": "messages" | "conversations" | None, "needs_standardization": bool, "sample_keys": list of keys found in messages (for debugging) } """ column_names = set(next(iter(dataset)).keys()) # Check for Alpaca alpaca_columns = {"instruction", "output"} if alpaca_columns.issubset(column_names): return { "format": "alpaca", "chat_column": None, "needs_standardization": False, "sample_keys": [], } # Check for chat-based formats (messages or conversations) chat_column = None if "messages" in column_names: chat_column = "messages" elif "conversations" in column_names: chat_column = "conversations" elif "texts" in column_names: chat_column = "texts" if chat_column: # Inspect the structure to determine if ShareGPT or ChatML try: sample = next(iter(dataset)) chat_data = sample[chat_column] if chat_data and len(chat_data) > 0: first_msg = chat_data[0] msg_keys = set(first_msg.keys()) # ShareGPT uses "from" and "value" if "from" in msg_keys or "value" in msg_keys: return { "format": "sharegpt", "chat_column": chat_column, "needs_standardization": True, "sample_keys": list(msg_keys), } # ChatML uses "role" and "content" elif "role" in msg_keys and "content" in msg_keys: return { "format": "chatml", "chat_column": chat_column, "needs_standardization": False, "sample_keys": list(msg_keys), } # Unknown structure but has chat column else: return { "format": "unknown", "chat_column": chat_column, "needs_standardization": None, "sample_keys": list(msg_keys), } except Exception as e: return { "format": "unknown", "chat_column": chat_column, "needs_standardization": None, 
"sample_keys": [], "error": str(e), } # No recognized format return { "format": "unknown", "chat_column": None, "needs_standardization": None, "sample_keys": [], } def detect_custom_format_heuristic(dataset): """ Smart detection with priority scoring. Strategy for ambiguous keywords like 'task': 1. Detect assistant first (unambiguous) 2. Detect user using high-priority keywords first 3. Check REMAINING columns for system keywords (including 'task') 4. Only if no system match, use 'task' as fallback user """ sample = next(iter(dataset)) all_columns = list(sample.keys()) mapping = {} # Keywords assistant_words = [ "output", "answer", "response", "assistant", "completion", "expected", "recommendation", "reply", "result", "target", "solution", "explanation", "solve", ] # Split into high/low priority user_words_high_priority = [ "input", "question", "query", "prompt", "instruction", "request", "snippet", "user", "text", "problem", "exercise", ] user_words_low_priority = ["task"] # Ambiguous - can be user OR system user_words = user_words_high_priority + user_words_low_priority system_words = [ "system", "context", "description", "persona", "role", "template", "task", # Also in system ] # Metadata columns to ignore metadata_exact_match = { "id", "idx", "index", "key", "timestamp", "date", "metadata", "source", "kind", "type", "category", "score", "label", "tag", "inference_mode", } metadata_prefix_patterns = [ "problem_type", "problem_source", "generation_model", "pass_rate", ] priority_patterns = { "generated": 100, "gen_": 90, "model_": 80, "predicted": 70, "completion": 60, } def has_keyword(col_name, keywords): """Check if any keyword appears in column name.""" col_lower = col_name.lower() col_normalized = col_lower.replace("_", "").replace("-", "").replace(" ", "") for keyword in keywords: if keyword in col_lower or keyword in col_normalized: return True return False def is_metadata(col_name): """Check if column is likely metadata.""" col_lower = col_name.lower() if 
col_lower in metadata_exact_match: return True if col_lower in metadata_prefix_patterns: return True for pattern in metadata_prefix_patterns: if ( col_lower.startswith(pattern.split("_")[0] + "_") and col_lower != pattern ): if "_" in col_lower: prefix = col_lower.split("_")[0] if prefix in ["generation", "pass", "inference"]: return True if len(col_lower) <= 2 and not col_lower in ["qa", "q", "a"]: return True return False def get_priority_score(col_name): """Calculate priority score based on column name patterns.""" col_lower = col_name.lower() score = 0 for pattern, pattern_score in priority_patterns.items(): if pattern in col_lower: score += pattern_score return score def get_content_length(col_name): """Get average content length for this column.""" try: if col_name in sample and sample[col_name]: content = str(sample[col_name]) return len(content) return 0 except: return 0 def score_column(col_name, keywords, role_type, num_candidates): """Score a column for how likely it is to be a particular role.""" if not has_keyword(col_name, keywords): return 0 score = 0 score += 10 # Penalize ambiguous keywords when scoring for user if role_type == "user": col_lower = col_name.lower() # If column is ONLY "task" (or task_xxx), give it lower priority for user role if "task" in col_lower and not any( kw in col_lower for kw in user_words_high_priority ): score -= 15 # Significant penalty so other user columns win priority_bonus = get_priority_score(col_name) score += priority_bonus if role_type in ["assistant", "user"]: avg_length = get_content_length(col_name) if num_candidates > 1: if avg_length > 1000: score += 50 elif avg_length > 200: score += 30 elif avg_length > 50: score += 10 elif avg_length < 50: score -= 20 else: if avg_length > 1000: score += 50 elif avg_length > 200: score += 30 elif avg_length > 50: score += 10 return score # Filter out metadata columns content_columns = [col for col in all_columns if not is_metadata(col)] # Count candidates first 
assistant_potential = [ col for col in content_columns if has_keyword(col, assistant_words) ] user_potential = [col for col in content_columns if has_keyword(col, user_words)] # STEP 1: Find best ASSISTANT column assistant_candidates = [] for col in assistant_potential: score = score_column( col, assistant_words, "assistant", len(assistant_potential) ) if score > 0: assistant_candidates.append((col, score)) if assistant_candidates: assistant_candidates.sort(key = lambda x: x[1], reverse = True) assistant_col = assistant_candidates[0][0] mapping[assistant_col] = "assistant" else: assistant_col = None # STEP 2: Find best USER column (with penalty for ambiguous keywords) user_candidates = [] for col in user_potential: if col == assistant_col: continue score = score_column(col, user_words, "user", len(user_potential)) if score > 0: user_candidates.append((col, score)) if user_candidates: user_candidates.sort(key = lambda x: x[1], reverse = True) user_col = user_candidates[0][0] mapping[user_col] = "user" else: user_col = None # STEP 3: Check ALL remaining columns for SYSTEM matches (priority check) remaining_columns = [col for col in content_columns if col not in mapping] system_col = None for col in remaining_columns: if has_keyword(col, system_words): # Found a system match in remaining columns mapping[col] = "system" system_col = col break # STEP 4: Handle any additional remaining columns if system_col: remaining_columns = [col for col in remaining_columns if col != system_col] if len(remaining_columns) >= 1: remaining_col = remaining_columns[0] # If no strong keyword match, decide based on what's missing if not has_keyword(remaining_col, user_words + assistant_words): mapping[remaining_col] = "system" elif user_col is None: # No user column yet, assign this as user mapping[remaining_col] = "user" else: # Already have user + assistant, treat as system context mapping[remaining_col] = "system" # VALIDATION: Ensure we have at least user + assistant has_user = any(role 
== "user" for role in mapping.values()) has_assistant = any(role == "assistant" for role in mapping.values()) if not has_user and len(remaining_columns) > 0: for col in remaining_columns: if col not in mapping: mapping[col] = "user" has_user = True break if has_user and has_assistant: return mapping return None def detect_multimodal_dataset(dataset): """ Detects if dataset contains multimodal data (images and/or audio). Two-pass approach for each modality: 1. Column-name heuristic (fast): checks for keywords. 2. Value-type inspection (reliable): checks actual sample values. Returns: dict: { "is_image": bool, "multimodal_columns": list of column names containing image data, "modality_types": list of detected types (e.g., ["image", "audio"]), "is_audio": bool, "audio_columns": list of column names containing audio data, "detected_audio_column": str or None, "detected_text_column": str or None, } """ sample = next(iter(dataset)) column_names = list(sample.keys()) # Keywords that indicate image data image_keywords = [ "image", "img", "pixel", "jpg", "jpeg", "png", "webp", "bmp", "gif", "tiff", "svg", "photo", "pic", "picture", "visual", "file_name", "filename", ] # Keywords that indicate audio data audio_keywords = ["audio", "speech", "wav", "waveform", "sound"] multimodal_columns = [] audio_columns = [] modality_types = set() # ── Image detection ───────────────────────────────────── # Pass 1: column-name heuristic (word-boundary match to avoid # false positives like 'pic' in 'topic') for col_name in column_names: for keyword in image_keywords: if _keyword_in_column(keyword, col_name): multimodal_columns.append(col_name) modality_types.add(keyword) break # Pass 2: inspect actual values already_detected = set(multimodal_columns) for col_name in column_names: if col_name in already_detected: continue value = sample[col_name] if _is_image_value(value): multimodal_columns.append(col_name) modality_types.add("image") # ── Audio detection 
───────────────────────────────────── # Pass 1: column-name heuristic (word-boundary match) for col_name in column_names: for keyword in audio_keywords: if _keyword_in_column(keyword, col_name): audio_columns.append(col_name) modality_types.add("audio") break # Pass 2: inspect actual values (catches non-obvious column names) already_audio = set(audio_columns) for col_name in column_names: if col_name in already_audio: continue value = sample[col_name] if _is_audio_value(value): audio_columns.append(col_name) modality_types.add("audio") # Filter out columns that are actually audio from the image list # (e.g. a column named "audio" with {"bytes", "path"} could match _is_image_value) if audio_columns: audio_set = set(audio_columns) multimodal_columns = [c for c in multimodal_columns if c not in audio_set] # Detect text column for audio datasets detected_text_col = None if audio_columns: text_keywords = ["text", "sentence", "transcript", "transcription", "label"] for col_name in column_names: if col_name.lower() in text_keywords: detected_text_col = col_name break is_audio = len(audio_columns) > 0 # Detect speaker_id column for TTS datasets (CSM, Orpheus, Spark) detected_speaker_col = None if audio_columns: speaker_keywords = ["source", "speaker", "speaker_id"] for col_name in column_names: if col_name.lower() in speaker_keywords: detected_speaker_col = col_name break return { "is_image": len(multimodal_columns) > 0, "multimodal_columns": multimodal_columns, "modality_types": list(modality_types), "is_audio": is_audio, "audio_columns": audio_columns, "detected_audio_column": audio_columns[0] if audio_columns else None, "detected_text_column": detected_text_col, "detected_speaker_column": detected_speaker_col, } def _is_image_value(value) -> bool: """Check if a single sample value looks like image data.""" if value is None: return False # PIL Image instance try: from PIL.Image import Image as PILImage if isinstance(value, PILImage): return True except ImportError: pass 
# HF datasets Image feature stores decoded images as PIL or dicts with # {"bytes": b"...", "path": "..."} when not yet decoded. # Exclude audio dicts (decoded audio has "array" + "sampling_rate"). if isinstance(value, dict): if "array" in value and "sampling_rate" in value: return False # This is audio, not image if "bytes" in value and "path" in value: # Check path extension to exclude audio files path = value.get("path") or "" if isinstance(path, str) and any( path.lower().endswith(ext) for ext in _AUDIO_EXTENSIONS ): return False return True # Raw bytes with a known image magic header if isinstance(value, (bytes, bytearray)): return _has_image_header(value) # String that looks like an image file path or URL _IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp", ".tiff", ".svg") if isinstance(value, str) and len(value) < 1000: lower = value.strip().lower() # Image URL (http://... ending in image extension) if lower.startswith(("http://", "https://")) and any( lower.split("?")[0].endswith(ext) for ext in _IMAGE_EXTS ): return True # Image file path (relative or absolute path ending in image extension) if any(lower.endswith(ext) for ext in _IMAGE_EXTS): return True return False _AUDIO_EXTENSIONS = ( ".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac", ".wma", ".webm", ) def _is_audio_value(value) -> bool: """Check if a single sample value looks like audio data.""" if value is None: return False # HF datasets Audio feature: decoded → {"array": np.ndarray, "sampling_rate": int} if isinstance(value, dict): if "array" in value and "sampling_rate" in value: return True # Undecoded/streaming → {"bytes": b"...", "path": "some.wav"} if "bytes" in value or "path" in value: path = value.get("path") or "" if isinstance(path, str) and any( path.lower().endswith(ext) for ext in _AUDIO_EXTENSIONS ): return True return False def _has_image_header(data: bytes) -> bool: """Quick magic-byte check for common image formats.""" if len(data) < 4: return False # JPEG if 
data[:2] == b"\xff\xd8": return True # PNG if data[:4] == b"\x89PNG": return True # GIF if data[:3] == b"GIF": return True # WebP if data[:4] == b"RIFF" and len(data) >= 12 and data[8:12] == b"WEBP": return True # BMP if data[:2] == b"BM": return True return False def detect_vlm_dataset_structure(dataset): """ Detects if VLM dataset is: - Standard VLM messages format (image objects in content) - Llava format (image indices + separate images column) - Simple format needing conversion (image + text columns) """ try: sample = next(iter(dataset)) except StopIteration: return { "format": "unknown", "needs_conversion": None, "image_column": None, "text_column": None, "messages_column": None, } column_names = set(sample.keys()) # Check if has messages column if "messages" in column_names: messages = sample["messages"] if messages and len(messages) > 0: first_msg = messages[0] if "content" in first_msg: content = first_msg["content"] if isinstance(content, list) and len(content) > 0: if isinstance(content[0], dict) and "type" in content[0]: # Check for llava format has_index = any( "index" in item for item in content if isinstance(item, dict) ) has_images_column = "images" in column_names if has_index and has_images_column: return { "format": "vlm_messages_llava", "needs_conversion": True, "messages_column": "messages", "image_column": "images", "text_column": None, } # Standard VLM format has_image = any( "image" in item for item in content if isinstance(item, dict) ) if has_image: return { "format": "vlm_messages", "needs_conversion": False, "messages_column": "messages", "image_column": None, "text_column": None, } # Check for ShareGPT/ChatML conversations with placeholder + companion image column # (e.g. 
Lin-Chen/ShareGPT4V, LLaVA-style datasets) for chat_col in ("conversations", "messages"): if chat_col not in column_names: continue chat_data = sample[chat_col] if not isinstance(chat_data, list) or len(chat_data) == 0: continue first_msg = chat_data[0] if not isinstance(first_msg, dict): continue # Detect ShareGPT (from/value) or ChatML (role/content) keys msg_text = first_msg.get("value") or first_msg.get("content") if not isinstance(msg_text, str): continue # Check for placeholder anywhere in the conversation has_image_placeholder = any( "" in str(m.get("value", "") or m.get("content", "")) for m in chat_data if isinstance(m, dict) ) if not has_image_placeholder: continue # Find companion image column image_col = None for col in column_names: if col == chat_col: continue if _keyword_in_column("image", col) or _keyword_in_column("img", col): image_col = col break if image_col: return { "format": "sharegpt_with_images", "needs_conversion": True, "image_column": image_col, "text_column": None, "messages_column": chat_col, } # Find image and text columns using metadata filtering # Define metadata patterns to EXCLUDE metadata_patterns = { "suffixes": [ "_id", "_url", "_name", "_filename", "_uri", "_link", "_key", "_index", ], "prefixes": [ "id_", "url_", "name_", "filename_", "uri_", "link_", "key_", "index_", ], } # Image-related keywords image_keywords = [ "image", "img", "photo", "picture", "pic", "visual", "scan", "file_name", "filename", ] # Text-related keywords text_keywords = [ "text", "caption", "captions", "description", "answer", "output", "response", "label", ] def is_metadata_column(col_name): """Check if column name looks like metadata.""" col_lower = col_name.lower() # Check suffixes if any(col_lower.endswith(suffix) for suffix in metadata_patterns["suffixes"]): return True # Check prefixes if any( col_lower.startswith(prefix) for prefix in metadata_patterns["prefixes"] ): return True return False def _score_image_candidate(col, sample_value): """Score 
a candidate image column by how resolvable its value is.""" # PIL Image object (highest priority - already loaded) if hasattr(sample_value, "size") and hasattr(sample_value, "mode"): return 100 # Dict with image data (bytes/path from HF Image feature) if isinstance(sample_value, dict) and ( "bytes" in sample_value or "path" in sample_value ): return 75 if isinstance(sample_value, str): # URL strings if sample_value.startswith(("http://", "https://")): return 70 if not is_metadata_column(col) else 55 # Bare file path if is_metadata_column(col): return 30 return 50 return 0 def _probe_image_candidate(col, sample_value): """Quick probe to check if an image candidate is actually reachable. Returns True if likely valid, False if definitely broken.""" import os # PIL / dict — already loaded, always valid if not isinstance(sample_value, str): return True # Local file — check it exists if not sample_value.startswith(("http://", "https://")): return os.path.exists( sample_value ) # bare filenames return False here, that's OK # URL — quick HEAD request with short timeout try: import urllib.request req = urllib.request.Request(sample_value, method = "HEAD") resp = urllib.request.urlopen(req, timeout = 3) return resp.status < 400 except Exception: return False def find_image_column(): """Find image column by keyword match + value-based fallback. 
When multiple candidates exist, probes them to find one that works.""" candidates = [] # Pass 1: keyword-matched columns for col in column_names: if any(_keyword_in_column(keyword, col) for keyword in image_keywords): sample_value = sample[col] score = _score_image_candidate(col, sample_value) if score > 0: candidates.append((col, score)) # Pass 2: value-based fallback — find columns with image URLs/paths # even if the column name doesn't match image keywords already = {c[0] for c in candidates} for col in column_names: if col in already: continue sample_value = sample[col] if _is_image_value(sample_value): score = _score_image_candidate(col, sample_value) # Slightly penalise non-keyword columns so keyword matches win on ties candidates.append((col, max(score - 5, 1))) if not candidates: return None candidates.sort(key = lambda x: x[1], reverse = True) # Single candidate or top candidate is PIL/dict — no probing needed if len(candidates) == 1 or candidates[0][1] >= 75: return candidates[0][0] # Multiple string-based candidates — probe to find one that actually works for col, score in candidates: sample_value = sample[col] if _probe_image_candidate(col, sample_value): return col # Nothing probed successfully — return highest-scored anyway and let # conversion handle the error (it may still resolve via hf_hub_download) return candidates[0][0] def find_text_column(): """Find text column by filtering out metadata and checking keywords.""" candidates = [] for col in column_names: # Skip metadata columns if is_metadata_column(col): continue # Check if contains text keywords (word-boundary match) if any(_keyword_in_column(keyword, col) for keyword in text_keywords): # Verify it's actually text sample_value = sample[col] if isinstance(sample_value, str) and len(sample_value) > 0: # Longer text = higher priority (likely content, not just a label) priority = min(len(sample_value), 1000) # Cap at 1000 candidates.append((col, priority)) elif ( isinstance(sample_value, list) 
and len(sample_value) > 0 and isinstance(sample_value[0], str) ): # List of strings (e.g. captions list) — lower priority than plain strings priority = min(len(sample_value[0]), 1000) // 2 candidates.append((col, priority)) # Return highest priority candidate if candidates: candidates.sort(key = lambda x: x[1], reverse = True) return candidates[0][0] return None found_image = find_image_column() found_text = find_text_column() if found_image and found_text: return { "format": "simple_image_text", "needs_conversion": True, "image_column": found_image, "text_column": found_text, "messages_column": None, } return { "format": "unknown", "needs_conversion": None, "image_column": found_image, "text_column": found_text, "messages_column": None, } ================================================ FILE: studio/backend/utils/datasets/llm_assist.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ LLM-assisted dataset analysis using an ephemeral GGUF helper model. Complements heuristic-based detection in format_detection.py and vlm_processing.py. Only invoked when heuristics are uncertain. Architecture: - Instantiates LlamaCppBackend, loads model, runs completion(s), unloads. - Not kept warm — VRAM is freed immediately after use. - Gracefully degrades: returns None when unavailable (no binary, OOM, disabled). """ import json import logging import os import re import textwrap import time from itertools import islice from typing import Any, Optional from loggers import get_logger logger = get_logger(__name__) DEFAULT_HELPER_MODEL_REPO = "unsloth/Qwen3.5-4B-GGUF" DEFAULT_HELPER_MODEL_VARIANT = "UD-Q4_K_XL" README_MAX_CHARS = 1500 def _strip_think_tags(text: str) -> str: """Strip ... reasoning blocks emitted by some models. If the model places its actual answer OUTSIDE the think block, we discard the think block and keep the rest. 
If the entire response is INSIDE a think block (nothing useful outside), we extract and return the inner content instead of discarding everything. """ if "" not in text: return text # Try stripping think blocks — keep content outside them stripped = re.sub(r".*?\s*", "", text, flags = re.DOTALL).strip() if stripped: return stripped # Everything was inside tags — extract the inner content of the last block matches = re.findall(r"(.*?)", text, flags = re.DOTALL) if matches: return matches[-1].strip() return text def precache_helper_gguf(): """ Pre-download the helper GGUF to HF cache. Called on FastAPI startup in a background thread so subsequent ``_run_with_helper()`` calls skip the download and only pay for llama-server startup. No-op if already cached or disabled. """ if os.environ.get("UNSLOTH_HELPER_MODEL_DISABLE", "").strip() in ("1", "true"): return repo = os.environ.get("UNSLOTH_HELPER_MODEL_REPO", DEFAULT_HELPER_MODEL_REPO) variant = os.environ.get( "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) try: from huggingface_hub import HfApi, hf_hub_download from huggingface_hub.utils import disable_progress_bars, enable_progress_bars disable_progress_bars() logging.getLogger("huggingface_hub").setLevel(logging.WARNING) # Find the GGUF file matching the variant api = HfApi() files = api.list_repo_files(repo, repo_type = "model") gguf_files = [f for f in files if f.endswith(".gguf")] # Find all GGUF files matching the variant (may be split into shards) variant_lower = variant.lower().replace("-", "_") matching = sorted( f for f in gguf_files if variant_lower in f.lower().replace("-", "_") ) if matching: logger.info( f"Pre-caching helper GGUF: {repo}/{matching[0]}" + (f" (+{len(matching) - 1} shards)" if len(matching) > 1 else "") ) for target in matching: hf_hub_download(repo_id = repo, filename = target) logger.info(f"Helper GGUF cached: {len(matching)} file(s)") else: logger.warning(f"No GGUF matching variant '{variant}' in {repo}") except Exception 
as e: logger.warning(f"Failed to pre-cache helper GGUF: {e}") finally: try: enable_progress_bars() except Exception as e: pass def _run_with_helper(prompt: str, max_tokens: int = 256) -> Optional[str]: """ Load helper model, run one chat completion, unload. Returns the completion text, or None on any failure. """ if os.environ.get("UNSLOTH_HELPER_MODEL_DISABLE", "").strip() in ("1", "true"): return None repo = os.environ.get("UNSLOTH_HELPER_MODEL_REPO", DEFAULT_HELPER_MODEL_REPO) variant = os.environ.get( "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) backend = None try: from core.inference.llama_cpp import LlamaCppBackend backend = LlamaCppBackend() logger.info(f"Loading helper model: {repo} ({variant})") ok = backend.load_model( hf_repo = repo, hf_variant = variant, model_identifier = f"helper:{repo}:{variant}", is_vision = False, n_ctx = 2048, n_gpu_layers = -1, ) if not ok: logger.warning("Helper model failed to start") return None messages = [{"role": "user", "content": prompt}] logger.info( "Helper model request: enable_thinking=False (per-request override)" ) cumulative = "" for text in backend.generate_chat_completion( messages = messages, temperature = 0.1, top_p = 0.9, top_k = 20, max_tokens = max_tokens, repetition_penalty = 1.0, enable_thinking = False, # Always disable thinking for AI Assist ): cumulative = text # cumulative — last value is full text result = cumulative.strip() result = _strip_think_tags(result) logger.info(f"Helper model response ({len(result)} chars)") return result if result else None except Exception as e: logger.warning(f"Helper model failed: {e}") return None finally: if backend is not None: try: backend.unload_model() logger.info("Helper model unloaded") except Exception: pass # ─── Public API ─────────────────────────────────────────────────────── def llm_generate_vlm_instruction( column_names: list[str], samples: list[dict], dataset_name: Optional[str] = None, ) -> Optional[dict]: """ Ask a helper LLM to 
generate a task-specific VLM instruction. Called when heuristic instruction generation returns low confidence or falls back to generic. Args: column_names: Column names in the dataset. samples: 3-5 sample rows with text values (images replaced by ""). dataset_name: Optional HF dataset identifier for context. Returns: {"instruction": str, "confidence": 0.85} or None. """ # Format samples for the prompt formatted = "" for i, row in enumerate(samples[:5], 1): parts = [] for col in column_names: val = str(row.get(col, ""))[:300] parts.append(f" {col}: {val}") formatted += f"Sample {i}:\n" + "\n".join(parts) + "\n\n" prompt = ( "You are a dataset analyst. Given a vision-language dataset, generate ONE " "instruction sentence that describes what the model should do with each image.\n\n" f"Dataset: {dataset_name or 'unknown'}\n" f"Columns: {column_names}\n\n" f"{formatted}" "Write ONE instruction sentence. Examples:\n" '- "Solve the math problem shown in the image and explain your reasoning."\n' '- "Transcribe all text visible in this image."\n' '- "Answer the question about this image."\n\n' "Respond with ONLY the instruction sentence, nothing else." ) result = _run_with_helper(prompt, max_tokens = 100) if not result: return None # Clean up: strip quotes, ensure it's a single sentence instruction = result.strip().strip('"').strip("'").strip() # Reject obviously bad outputs (too short, too long, or multi-line) if len(instruction) < 10 or len(instruction) > 200 or "\n" in instruction: logger.warning(f"Helper model returned unusable instruction: {instruction!r}") return None logger.info(f"LLM-generated instruction: {instruction}") return { "instruction": instruction, "confidence": 0.85, } def llm_classify_columns( column_names: list[str], samples: list[dict], ) -> Optional[dict[str, str]]: """ Ask a helper LLM to classify dataset columns into roles. Called when heuristic column detection fails (returns None). Args: column_names: Column names in the dataset. 
samples: 3-5 sample rows with values truncated to 200 chars. Returns: Dict mapping column_name → role ("user"|"assistant"|"system"|"metadata"), or None on failure. """ formatted = "" for i, row in enumerate(samples[:5], 1): parts = [] for col in column_names: val = str(row.get(col, ""))[:200] parts.append(f" {col}: {val}") formatted += f"Sample {i}:\n" + "\n".join(parts) + "\n\n" prompt = ( "Classify each column in this dataset into one of these roles:\n" "- user: The input/question/prompt from the human\n" "- assistant: The expected output/answer/response from the AI\n" "- system: Context, persona, or task description\n" "- metadata: IDs, scores, labels, timestamps — not part of conversation\n\n" f"Columns: {column_names}\n\n" f"{formatted}" "Respond with ONLY a JSON object mapping column names to roles.\n" 'Example: {"question": "user", "answer": "assistant", "id": "metadata"}' ) result = _run_with_helper(prompt, max_tokens = 200) if not result: return None # Parse JSON from response (may have markdown fences) text = result.strip() if text.startswith("```"): # Strip markdown code fence lines = text.split("\n") text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) text = text.strip() try: mapping = json.loads(text) except json.JSONDecodeError: # Try to find JSON object in the response import re match = re.search(r"\{[^}]+\}", text) if match: try: mapping = json.loads(match.group()) except json.JSONDecodeError: logger.warning(f"Could not parse helper model JSON: {text!r}") return None else: logger.warning(f"No JSON found in helper model response: {text!r}") return None if not isinstance(mapping, dict): return None # Validate: all values must be valid roles valid_roles = {"user", "assistant", "system", "metadata"} cleaned = {} for col, role in mapping.items(): if ( col in column_names and isinstance(role, str) and role.lower() in valid_roles ): cleaned[col] = role.lower() if not cleaned: return None # Must have at least user + assistant 
roles_present = set(cleaned.values()) if "user" not in roles_present or "assistant" not in roles_present: logger.warning(f"Helper model mapping missing user/assistant: {cleaned}") return None logger.info(f"LLM-classified columns: {cleaned}") return cleaned def llm_generate_dataset_warning( issues: list[str], dataset_name: Optional[str] = None, modality: str = "text", column_names: Optional[list[str]] = None, ) -> Optional[str]: """ Ask the helper LLM to turn technical dataset issues into a user-friendly warning. Works for all modalities (text, vision, audio). Args: issues: List of technical issue descriptions found during analysis. dataset_name: Optional HF dataset name. modality: "text", "vision", or "audio". column_names: Optional list of column names for context. Returns: A human-friendly warning string, or None on failure. """ if not issues: return None issues_text = "\n".join(f"- {issue}" for issue in issues) cols_text = f"\nColumns: {column_names}" if column_names else "" prompt = ( "You are a helpful assistant. A user is trying to fine-tune a model on a dataset.\n" "The following issues were found during dataset analysis:\n\n" f"{issues_text}\n\n" f"Dataset: {dataset_name or 'unknown'}\n" f"Modality: {modality}" f"{cols_text}\n\n" "Write a brief, friendly explanation of what's wrong and what the user can do about it.\n" "Keep it under 3 sentences. Be specific about the dataset." 
) result = _run_with_helper(prompt, max_tokens = 200) if not result: return None warning = result.strip() # Reject obviously bad outputs if len(warning) < 10 or len(warning) > 500: return None logger.info(f"LLM-generated warning: {warning}") return warning # ─── Dataset Conversion Advisor ────────────────────────────────────── def _parse_json_response(text: str) -> Optional[dict]: """Parse JSON from LLM response, handling markdown fences and noise.""" if not text: return None cleaned = text.strip() # Strip markdown code fences if cleaned.startswith("```"): lines = cleaned.split("\n") end = -1 if lines[-1].strip().startswith("```") else len(lines) cleaned = "\n".join(lines[1:end]).strip() # Try direct parse try: obj = json.loads(cleaned) if isinstance(obj, dict): return obj except json.JSONDecodeError: pass # Greedy match for outermost {...} match = re.search(r"\{.*\}", cleaned, re.DOTALL) if match: try: obj = json.loads(match.group()) if isinstance(obj, dict): return obj except json.JSONDecodeError: pass return None def _generate_with_backend(backend, messages: list[dict], max_tokens: int = 512) -> str: """Run one chat completion on an already-loaded backend. Returns raw text.""" logger.info("Advisor request: enable_thinking=False (per-request override)") cumulative = "" for text in backend.generate_chat_completion( messages = messages, temperature = 0.1, top_p = 0.9, top_k = 20, max_tokens = max_tokens, repetition_penalty = 1.0, enable_thinking = False, # Always disable thinking for AI Assist ): cumulative = text result = cumulative.strip() result = _strip_think_tags(result) return result def fetch_hf_dataset_card( dataset_name: str, hf_token: Optional[str] = None ) -> tuple[Optional[str], Optional[dict]]: """ Fetch HF dataset card (README) and metadata. Returns: (readme_text, metadata_dict) or (None, None) on failure. 
""" try: from huggingface_hub import DatasetCard card = DatasetCard.load(dataset_name, token = hf_token) readme = card.text or "" # Truncate at sentence boundary if len(readme) > README_MAX_CHARS: cut = readme[:README_MAX_CHARS].rfind(".") if cut > README_MAX_CHARS // 2: readme = readme[: cut + 1] + "\n[...truncated]" else: readme = readme[:README_MAX_CHARS] + "\n[...truncated]" # Extract metadata from YAML frontmatter metadata = {} if card.data: for key in ( "task_categories", "task_ids", "language", "size_categories", "tags", "license", "pretty_name", ): val = getattr(card.data, key, None) if val is not None: metadata[key] = val logger.info( f"Fetched dataset card: {len(readme)} chars, {len(metadata)} metadata fields" ) return readme, metadata except Exception as e: logger.warning(f"Could not fetch dataset card for {dataset_name}: {e}") return None, None def _run_multi_pass_advisor( columns: list[str], samples: list[dict], dataset_name: Optional[str] = None, dataset_card: Optional[str] = None, dataset_metadata: Optional[dict] = None, model_name: Optional[str] = None, model_type: Optional[str] = None, hf_token: Optional[str] = None, ) -> Optional[dict[str, Any]]: """ Multi-pass LLM analysis: classify → convert → validate. Keeps model loaded across all passes. Returns combined result dict or None. 
""" if os.environ.get("UNSLOTH_HELPER_MODEL_DISABLE", "").strip() in ("1", "true"): return None repo = os.environ.get("UNSLOTH_HELPER_MODEL_REPO", DEFAULT_HELPER_MODEL_REPO) variant = os.environ.get( "UNSLOTH_HELPER_MODEL_VARIANT", DEFAULT_HELPER_MODEL_VARIANT ) backend = None try: from core.inference.llama_cpp import LlamaCppBackend backend = LlamaCppBackend() logger.info(f"Loading advisor model: {repo} ({variant})") t0 = time.monotonic() ok = backend.load_model( hf_repo = repo, hf_variant = variant, model_identifier = f"advisor:{repo}:{variant}", is_vision = False, n_ctx = 2048, n_gpu_layers = -1, ) if not ok: logger.warning("Advisor model failed to start") return None logger.info(f"Advisor model loaded in {time.monotonic() - t0:.1f}s") # ── Format samples ── samples_text = "" for i, row in enumerate(samples[:5], 1): parts = [f" {col}: {str(row.get(col, ''))[:200]}" for col in columns] samples_text += f"Row {i}:\n" + "\n".join(parts) + "\n" metadata_str = ( json.dumps(dataset_metadata, indent = 2, default = str)[:500] if dataset_metadata else "N/A" ) card_excerpt = (dataset_card or "")[:1200] or "N/A" # ── Target Model Hints ── target_hints = "" is_gemma_3n = False if model_name: try: from utils.models.model_config import load_model_config config = load_model_config( model_name, use_auth = True, token = hf_token, trust_remote_code = False, ) archs = getattr(config, "architectures", []) if archs and "Gemma3nForConditionalGeneration" in archs: is_gemma_3n = True except Exception: is_gemma_3n = "gemma-3n" in model_name.lower() if model_type == "audio" and not is_gemma_3n: target_hints = ( "\n\nHINT: The user is training an AUDIO model. The dataset MUST contain " "a column with audio files/paths. Ensure one such column is selected " "as part of the input." ) elif model_type == "embeddings": target_hints = ( "\n\nHINT: The user is training an EMBEDDING model. 
These models typically " "do not use standard conversational input/output formats but instead use " "specific formats like:\n" "- Pairs of texts for Semantic Textual Similarity (STS)\n" "- Premise, hypothesis, and label for Natural Language Inference (NLI)\n" "- Queries and positive/negative documents for information retrieval\n" "Ensure the dataset format mapped reflects these specialized tasks." ) # ── Pass 1: Classify ── logger.info("Pass 1: Classifying dataset...") t1 = time.monotonic() messages1 = [ { "role": "system", "content": ( "You are a dataset analyst. Your job is to look at a HuggingFace dataset " "and figure out what kind of data it contains and whether it is already in " "a conversational format suitable for LLM fine-tuning. A dataset is " '"conversational" if it already has columns like "messages", "conversations", ' 'or multiturn "user"/"assistant" pairs. Some datasets are NOT conversational ' "— they are things like summarization, question answering, translation, " "classification, etc. Those need conversion. You must respond with ONLY a " "valid JSON object. Do not write any explanation before or after the JSON." f"{target_hints}" ), }, { "role": "user", "content": textwrap.dedent(f"""\ Look at this HuggingFace dataset and classify it. DATASET CARD (excerpt): {card_excerpt} METADATA: {metadata_str} COLUMNS: {columns} SAMPLE DATA (first 3 rows): {samples_text} Based on the above, respond with this exact JSON structure: {{ "dataset_type": "", "is_conversational": , "needs_conversion": , "description": "", "task_description": "" }} Respond with ONLY the JSON object. 
No markdown, no explanation."""), }, ] raw1 = _generate_with_backend(backend, messages1, max_tokens = 256) pass1 = _parse_json_response(raw1) logger.info(f"Pass 1 done ({time.monotonic() - t1:.1f}s): {pass1}") if not pass1: logger.warning(f"Advisor Pass 1 failed to produce JSON: {raw1[:200]}") return None # If dataset is already conversational, skip passes 2-3 if pass1.get("is_conversational") and not pass1.get("needs_conversion"): return { "success": True, "dataset_type": pass1.get("dataset_type"), "is_conversational": True, "user_notification": ( "This dataset is already in conversational format. " "No conversion needed — columns can be mapped directly." ), } # ── Pass 2: Map columns to roles ── logger.info("Pass 2: Mapping columns to roles...") t2 = time.monotonic() messages2 = [ { "role": "system", "content": ( "You are a data preparation assistant. Your job is to assign each column " "in a dataset to a conversation role for LLM fine-tuning. There are exactly " "two roles:\n" '- "user" = This column contains INPUT that the model will receive as a prompt.\n' '- "assistant" = This column contains OUTPUT that the model should learn to generate.\n\n' "CRITICAL RULES:\n" '1. There MUST be at least one column assigned to "user" AND at least one ' 'column assigned to "assistant". Never assign all columns to the same role.\n' "2. The column that contains the TARGET or OUTPUT or ANSWER or LABEL must " 'ALWAYS be assigned to "assistant". This is the thing the model should learn ' "to produce.\n" "3. The columns that contain the SOURCE or INPUT or CONTEXT or QUESTION must " 'be assigned to "user". This is what the model receives.\n' '4. Metadata columns like "id", "index", "source", "url", "date" should be ' 'set to "skip".\n\n' "You must respond with ONLY a valid JSON object." 
f"{target_hints}" ), }, { "role": "user", "content": textwrap.dedent(f"""\ Here is a dataset that has been classified: CLASSIFICATION: {json.dumps(pass1, indent = 2)} COLUMNS AVAILABLE: {columns} SAMPLE DATA (first 3 rows): {samples_text} Your task: assign each column to either "user", "assistant", or "skip". Here are worked examples to guide you: Example 1 — Summarization dataset with columns ["document", "summary"]: "document" is the input text → "user" "summary" is the output the model should generate → "assistant" Result: {{"document": "user", "summary": "assistant"}} Example 2 — Question answering dataset with columns ["context", "question", "answer"]: "context" is input → "user" "question" is input → "user" "answer" is what the model should generate → "assistant" Result: {{"context": "user", "question": "user", "answer": "assistant"}} Example 3 — Classification dataset with columns ["text", "label"]: "text" is input → "user" "label" is the output the model should predict → "assistant" Result: {{"text": "user", "label": "assistant"}} Example 4 — Translation dataset with columns ["en", "fr"]: "en" is the source language (input) → "user" "fr" is the target language (output) → "assistant" Result: {{"en": "user", "fr": "assistant"}} Now apply this logic to the actual dataset columns listed above. Respond with this exact JSON structure: {{ "column_roles": {{ "": "" }}, "label_mapping": , "notes": "" }} REMEMBER: There must be at least one "user" column AND at least one "assistant" column. If all columns are "user", you made a mistake — the output/target column should be "assistant". 
Respond with ONLY the JSON object."""), }, ] raw2 = _generate_with_backend(backend, messages2, max_tokens = 512) pass2 = _parse_json_response(raw2) logger.info(f"Pass 2 done ({time.monotonic() - t2:.1f}s): {pass2}") if not pass2: logger.warning(f"Advisor Pass 2 failed to produce JSON: {raw2[:200]}") return None # ── Extract and validate column roles from Pass 2 ── column_roles = pass2.get("column_roles", {}) label_map = pass2.get("label_mapping") or {} # may be null # Validate: must have at least one user AND one assistant roles_present = set(column_roles.values()) if "user" not in roles_present or "assistant" not in roles_present: logger.warning( f"Pass 2 sanity fail: missing user or assistant role: {column_roles}" ) return None # triggers fallback to simple classification # ── Pass 3: System prompt (non-conversational datasets only) ── sys_prompt = "" dtype = pass1.get("dataset_type", "unknown") is_conv = pass1.get("is_conversational", False) if not is_conv: logger.info("Pass 3: Generating system prompt...") t3 = time.monotonic() # Format label mapping info for the prompt label_info = "" if label_map: for col, mapping in label_map.items(): if isinstance(mapping, dict) and mapping: pairs = ", ".join(f"{k} = {v}" for k, v in mapping.items()) label_info += f"\nLabel mapping for '{col}': {pairs}" # Describe the role assignments for context user_cols = [c for c, r in column_roles.items() if r == "user"] asst_cols = [c for c, r in column_roles.items() if r == "assistant"] task_desc = pass1.get("task_description") or pass1.get("description", "") messages3 = [ { "role": "user", "content": textwrap.dedent(f"""\ I am building a fine-tuning dataset for an LLM. I need you to write a \ system prompt that will be included in every training example to tell \ the model what task it is performing. 
Here is the task information: - Dataset type: {dtype} - Task description: {task_desc} - The USER (input) columns are: {user_cols} - The ASSISTANT (output) columns are: {asst_cols} {label_info} Write a system prompt that: 1. Explains what task the model is performing in plain language 2. Describes what input it will receive 3. Describes what output it should produce 4. Is 2-4 sentences long Write ONLY the system prompt text. No quotes, no labels, no explanation around it."""), }, ] raw3 = _generate_with_backend(backend, messages3, max_tokens = 256) logger.info( f"Pass 3 done ({time.monotonic() - t3:.1f}s): {raw3[:200] if raw3 else None}" ) if raw3: # Pass 3 returns raw text, not JSON — clean it up cleaned = raw3.strip().strip('"').strip("'").strip() if len(cleaned) >= 20 and cleaned.lower() not in ("null", "none", ""): sys_prompt = cleaned # Build suggested_mapping (column → role, for the frontend dropdowns) suggested_mapping = {} for col, role in column_roles.items(): if col in columns and role in ("user", "assistant", "system"): suggested_mapping[col] = role # Build user notification from Pass 1 classification desc = pass1.get("task_description") or pass1.get("description", "") note_parts = [f"This is a {dtype} dataset (not conversational)."] if desc: note_parts.append(desc) note_parts.append( "Columns have been mapped to conversation roles. You can adjust the mapping if needed." 
) user_notification = " ".join(note_parts) total_time = time.monotonic() - t0 logger.info( f"Advisor complete ({total_time:.1f}s): type={dtype}, mapping={suggested_mapping}, sys_prompt={bool(sys_prompt)}, label_map={bool(label_map)}" ) return { "success": True, "suggested_mapping": suggested_mapping, "system_prompt": sys_prompt, "label_mapping": label_map if label_map else None, "dataset_type": dtype, "is_conversational": is_conv, "user_notification": user_notification, } except Exception as e: logger.warning(f"Advisor multi-pass failed: {e}") return None finally: if backend is not None: try: backend.unload_model() logger.info("Advisor model unloaded") except Exception: pass def llm_conversion_advisor( column_names: list[str], samples: list[dict], dataset_name: Optional[str] = None, hf_token: Optional[str] = None, model_name: Optional[str] = None, model_type: Optional[str] = None, ) -> Optional[dict[str, Any]]: """ Full conversion advisor: fetch HF card → multi-pass LLM analysis. Falls back to simple llm_classify_columns() if the multi-pass advisor fails. Returns: Dict with keys: success, suggested_mapping, system_prompt, user_template, assistant_template, label_mapping, dataset_type, is_conversational, user_notification. Or None on complete failure. 
""" # Fetch HF dataset card if this looks like a HF dataset (has a slash) dataset_card = None dataset_metadata = None if dataset_name and "/" in dataset_name: dataset_card, dataset_metadata = fetch_hf_dataset_card(dataset_name, hf_token) # Try multi-pass advisor result = _run_multi_pass_advisor( columns = column_names, samples = samples, dataset_name = dataset_name, dataset_card = dataset_card, dataset_metadata = dataset_metadata, model_name = model_name, model_type = model_type, hf_token = hf_token, ) if result and result.get("success"): logger.info(f"Conversion advisor succeeded: type={result.get('dataset_type')}") return result # Fallback: simple column classification logger.info("Advisor failed, falling back to simple column classification") simple_mapping = llm_classify_columns(column_names, samples) if simple_mapping: return { "success": True, "suggested_mapping": { col: role for col, role in simple_mapping.items() if role in ("user", "assistant", "system") }, "dataset_type": None, "is_conversational": None, "user_notification": None, } return None ================================================ FILE: studio/backend/utils/datasets/model_mappings.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Model and template mappings for dataset processing. This module contains the mapping dictionaries that associate model names with their corresponding chat templates and response markers. 
""" TEMPLATE_TO_MODEL_MAPPER = { "phi-3.5": ( "unsloth/Phi-3.5-mini-instruct-bnb-4bit", "unsloth/Phi-3.5-mini-instruct", "microsoft/Phi-3.5-mini-instruct", ), "phi-3": ( "unsloth/Phi-3-mini-4k-instruct-bnb-4bit", "unsloth/Phi-3-mini-4k-instruct", "microsoft/Phi-3-mini-4k-instruct", "unsloth/Phi-3-medium-4k-instruct-bnb-4bit", "unsloth/Phi-3-medium-4k-instruct", "microsoft/Phi-3-medium-4k-instruct", "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit", "unsloth/Phi-3-mini-4k-instruct-v0", ), "phi-4": ( "unsloth/phi-4-unsloth-bnb-4bit", "unsloth/phi-4", "microsoft/phi-4", "unsloth/phi-4-bnb-4bit", "unsloth/phi-4-reasoning-unsloth-bnb-4bit", "unsloth/phi-4-reasoning", "microsoft/Phi-4-reasoning", "unsloth/phi-4-reasoning-bnb-4bit", "unsloth/phi-4-reasoning-plus-unsloth-bnb-4bit", "unsloth/phi-4-reasoning-plus", "microsoft/Phi-4-reasoning-plus", "unsloth/phi-4-reasoning-plus-bnb-4bit", "unsloth/phi-4-mini-reasoning-unsloth-bnb-4bit", "unsloth/phi-4-mini-reasoning", "microsoft/Phi-4-mini-reasoning", "unsloth/phi-4-mini-reasoning-bnb-4bit", "unsloth/Phi-4-mini-instruct-unsloth-bnb-4bit", "unsloth/Phi-4-mini-instruct", "microsoft/Phi-4-mini-instruct", "unsloth/Phi-4-mini-instruct-bnb-4bit", ), "mistral": ( "unsloth/mistral-7b-instruct-v0.1-bnb-4bit", "unsloth/mistral-7b-instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", "unsloth/mistral-7b-instruct-v0.2-bnb-4bit", "unsloth/mistral-7b-instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.2", "unsloth/mistral-7b-instruct-v0.3-bnb-4bit", "unsloth/mistral-7b-instruct-v0.3", "mistralai/Mistral-7B-Instruct-v0.3", "unsloth/Mixtral-8x7B-Instruct-v0.1-unsloth-bnb-4bit", "unsloth/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1", "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit", "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit", "unsloth/Mistral-Nemo-Instruct-2407", "mistralai/Mistral-Nemo-Instruct-2407", "unsloth/Mistral-Large-Instruct-2407-bnb-4bit", "mistralai/Mistral-Large-Instruct-2407", 
"unsloth/Mistral-Small-Instruct-2409-bnb-4bit", "unsloth/Mistral-Small-Instruct-2409", "mistralai/Mistral-Small-Instruct-2409", "unsloth/Mistral-Small-24B-Instruct-2501-unsloth-bnb-4bit", "unsloth/Mistral-Small-24B-Instruct-2501", "mistralai/Mistral-Small-24B-Instruct-2501", "unsloth/Mistral-Small-24B-Instruct-2501-bnb-4bit", "unsloth/Mistral-Small-3.1-24B-Instruct-2503-unsloth-bnb-4bit", "unsloth/Mistral-Small-3.1-24B-Instruct-2503", "mistralai/Mistral-Small-3.1-24B-Instruct-2503", "unsloth/Mistral-Small-3.1-24B-Instruct-2503-bnb-4bit", "unsloth/Mistral-Small-3.2-24B-Instruct-2506-unsloth-bnb-4bit", "unsloth/Mistral-Small-3.2-24B-Instruct-2506", "mistralai/Mistral-Small-3.2-24B-Instruct-2506", "unsloth/Mistral-Small-3.2-24B-Instruct-2506-bnb-4bit", ), "llama": ( "meta-llama/Llama-2-13b-chat-hf", "unsloth/llama-2-7b-chat-bnb-4bit", "unsloth/llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf", ), "llama3": ( "unsloth/llama-3-8b-Instruct-bnb-4bit", "unsloth/llama-3-8b-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct", "unsloth/llama-3-70b-Instruct-bnb-4bit", "meta-llama/Meta-Llama-3-70B-Instruct", ), "llama-3.1": ( "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit", "unsloth/Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit", "unsloth/Llama-3.1-8B-Instruct-unsloth-bnb-4bit", "unsloth/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct", "unsloth/Llama-3.1-8B-Instruct-bnb-4bit", "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit", "meta-llama/Meta-Llama-3.1-405B-Instruct", "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit", "unsloth/Meta-Llama-3.1-70B-Instruct", "meta-llama/Meta-Llama-3.1-70B-Instruct", "unsloth/Llama-3.1-Storm-8B-bnb-4bit", "unsloth/Llama-3.1-Storm-8B", "akjindal53244/Llama-3.1-Storm-8B", "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit", "unsloth/Hermes-3-Llama-3.1-8B", "NousResearch/Hermes-3-Llama-3.1-8B", "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit", "unsloth/Hermes-3-Llama-3.1-70B", 
"NousResearch/Hermes-3-Llama-3.1-70B", "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit", "NousResearch/Hermes-3-Llama-3.1-405B", "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit", "unsloth/Llama-3.1-Nemotron-70B-Instruct", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit", "unsloth/Llama-3.1-Tulu-3-8B", "allenai/Llama-3.1-Tulu-3-8B", "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit", "unsloth/Llama-3.1-Tulu-3-70B", "allenai/Llama-3.1-Tulu-3-70B", ), "llama-3.2": ( "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit", "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", "unsloth/Llama-3.2-1B-Instruct-bnb-4bit", "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit", "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", "unsloth/Llama-3.2-3B-Instruct-bnb-4bit", "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit", "unsloth/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision-Instruct", "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit", "unsloth/Llama-3.2-90B-Vision-Instruct", "meta-llama/Llama-3.2-90B-Vision-Instruct", ), "llama-3.3": ( "unsloth/Llama-3.3-70B-Instruct-bnb-4bit", "unsloth/Llama-3.3-70B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", ), "gemma": ( "unsloth/gemma-7b-it-bnb-4bit", "unsloth/gemma-7b-it", "google/gemma-7b-it", "google/gemma-2b-it", "unsloth/gemma-1.1-2b-it-bnb-4bit", "unsloth/gemma-1.1-2b-it", "google/gemma-1.1-2b-it", "unsloth/gemma-1.1-7b-it-bnb-4bit", "unsloth/gemma-1.1-7b-it", "google/gemma-1.1-7b-it", ), "gemma2": ( "unsloth/gemma-2-9b-it-bnb-4bit", "unsloth/gemma-2-9b-it", "google/gemma-2-9b-it", "unsloth/gemma-2-27b-it-bnb-4bit", "unsloth/gemma-2-27b-it", "google/gemma-2-27b-it", "unsloth/gemma-2-2b-it-bnb-4bit", "unsloth/gemma-2-2b-it", "google/gemma-2-2b-it", ), "gemma-3": ( "unsloth/gemma-3-1b-it-unsloth-bnb-4bit", "unsloth/gemma-3-1b-it", "google/gemma-3-1b-it", "unsloth/gemma-3-1b-it-bnb-4bit", 
"unsloth/gemma-3-4b-it-unsloth-bnb-4bit", "unsloth/gemma-3-4b-it", "google/gemma-3-4b-it", "unsloth/gemma-3-4b-it-bnb-4bit", "unsloth/gemma-3-12b-it-unsloth-bnb-4bit", "unsloth/gemma-3-12b-it", "google/gemma-3-12b-it", "unsloth/gemma-3-12b-it-bnb-4bit", "unsloth/gemma-3-27b-it-unsloth-bnb-4bit", "unsloth/gemma-3-27b-it", "google/gemma-3-27b-it", "unsloth/gemma-3-27b-it-bnb-4bit", "unsloth/gemma-3-270m-it-unsloth-bnb-4bit", "unsloth/gemma-3-270m-it", "google/gemma-3-270m-it", "unsloth/gemma-3-270m-it-bnb-4bit", "unsloth/gemma-3-270m-unsloth-bnb-4bit", "unsloth/medgemma-4b-it-unsloth-bnb-4bit", "unsloth/medgemma-4b-it", "google/medgemma-4b-it", "unsloth/medgemma-4b-it-bnb-4bit", "unsloth/medgemma-27b-text-it-unsloth-bnb-4bit", "unsloth/medgemma-27b-text-it", "google/medgemma-27b-text-it", "unsloth/medgemma-27b-text-it-bnb-4bit", ), "gemma3n": ( "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit", "unsloth/gemma-3n-E4B-it", "google/gemma-3n-E4B-it", "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit", "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit", "unsloth/gemma-3n-E2B-it", "google/gemma-3n-E2B-it", "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit", ), "qwen2.5": ( "unsloth/Qwen2.5-0.5B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct", "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit", "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct", "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit", "unsloth/Qwen2.5-3B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-3B-Instruct", "Qwen/Qwen2.5-3B-Instruct", "unsloth/Qwen2.5-3B-Instruct-bnb-4bit", "unsloth/Qwen2.5-7B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-7B-Instruct", "Qwen/Qwen2.5-7B-Instruct", "unsloth/Qwen2.5-7B-Instruct-bnb-4bit", "unsloth/Qwen2.5-14B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-14B-Instruct", "Qwen/Qwen2.5-14B-Instruct", "unsloth/Qwen2.5-14B-Instruct-bnb-4bit", "unsloth/Qwen2.5-32B-Instruct-bnb-4bit", "unsloth/Qwen2.5-32B-Instruct", 
"Qwen/Qwen2.5-32B-Instruct", "unsloth/Qwen2.5-72B-Instruct-bnb-4bit", "unsloth/Qwen2.5-72B-Instruct", "Qwen/Qwen2.5-72B-Instruct", "unsloth/Qwen2.5-0.5B-unsloth-bnb-4bit", "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Math-1.5B-Instruct", "Qwen/Qwen2.5-Math-1.5B-Instruct", "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Math-7B-Instruct", "Qwen/Qwen2.5-Math-7B-Instruct", "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Math-72B-Instruct", "Qwen/Qwen2.5-Math-72B-Instruct", "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-0.5B-Instruct", "Qwen/Qwen2.5-Coder-0.5B-Instruct", "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-1.5B-Instruct", "Qwen/Qwen2.5-Coder-1.5B-Instruct", "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-3B-Instruct", "Qwen/Qwen2.5-Coder-3B-Instruct", "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-7B-Instruct", "Qwen/Qwen2.5-Coder-7B-Instruct", "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-14B-Instruct", "Qwen/Qwen2.5-Coder-14B-Instruct", "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit", "unsloth/Qwen2.5-Coder-32B-Instruct", "Qwen/Qwen2.5-Coder-32B-Instruct", "unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2.5-VL-3B-Instruct", "unsloth/Qwen2.5-VL-3B-Instruct-bnb-4bit", "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-VL-7B-Instruct", "Qwen/Qwen2.5-VL-7B-Instruct", "unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "unsloth/Qwen2.5-VL-32B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-VL-32B-Instruct", "Qwen/Qwen2.5-VL-32B-Instruct", "unsloth/Qwen2.5-VL-32B-Instruct-bnb-4bit", "unsloth/Qwen2.5-VL-72B-Instruct-unsloth-bnb-4bit", "unsloth/Qwen2.5-VL-72B-Instruct", "Qwen/Qwen2.5-VL-72B-Instruct", "unsloth/Qwen2.5-VL-72B-Instruct-bnb-4bit", "unsloth/OpenThinker-7B-unsloth-bnb-4bit", "unsloth/OpenThinker-7B", "open-thoughts/OpenThinker-7B", 
"unsloth/OpenThinker-7B-bnb-4bit", ), "qwen3": ( "unsloth/Qwen3-0.6B-unsloth-bnb-4bit", "unsloth/Qwen3-0.6B", "Qwen/Qwen3-0.6B", "unsloth/Qwen3-0.6B-bnb-4bit", "unsloth/Qwen3-1.7B-unsloth-bnb-4bit", "unsloth/Qwen3-1.7B", "Qwen/Qwen3-1.7B", "unsloth/Qwen3-1.7B-bnb-4bit", "unsloth/Qwen3-4B-unsloth-bnb-4bit", "unsloth/Qwen3-4B", "Qwen/Qwen3-4B", "unsloth/Qwen3-4B-bnb-4bit", "unsloth/Qwen3-8B-unsloth-bnb-4bit", "unsloth/Qwen3-8B", "Qwen/Qwen3-8B", "unsloth/Qwen3-8B-bnb-4bit", "unsloth/Qwen3-14B-unsloth-bnb-4bit", "unsloth/Qwen3-14B", "Qwen/Qwen3-14B", "unsloth/Qwen3-14B-bnb-4bit", "unsloth/Qwen3-32B-unsloth-bnb-4bit", "unsloth/Qwen3-32B", "Qwen/Qwen3-32B", "unsloth/Qwen3-32B-bnb-4bit", "unsloth/Qwen3-30B-A3B-unsloth-bnb-4bit", "unsloth/Qwen3-30B-A3B", "Qwen/Qwen3-30B-A3B", "unsloth/Qwen3-30B-A3B-bnb-4bit", ), "qwen3-instruct": ( "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", "unsloth/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-4B-Instruct-2507", "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit", "unsloth/Qwen3-30B-A3B-Instruct-2507", "Qwen/Qwen3-30B-A3B-Instruct-2507", "unsloth/Qwen3-Coder-30B-A3B-Instruct", "Qwen/Qwen3-Coder-30B-A3B-Instruct", "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit", "unsloth/Qwen3-4B-Instruct-2507", "Qwen/Qwen3-4B-Instruct-2507", "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit", ), "qwen3-thinking": ( "unsloth/QwQ-32B-Preview-bnb-4bit", "unsloth/QwQ-32B-Preview", "Qwen/QwQ-32B-Preview", "unsloth/QwQ-32B-unsloth-bnb-4bit", "unsloth/QwQ-32B", "Qwen/QwQ-32B", "unsloth/QwQ-32B-bnb-4bit", "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit", "unsloth/Qwen3-4B-Thinking-2507", "Qwen/Qwen3-4B-Thinking-2507", "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit", "unsloth/Qwen3-30B-A3B-Thinking-2507", "Qwen/Qwen3-30B-A3B-Thinking-2507", ), "qwen3.5": ( "unsloth/Qwen3.5-0.8B", "unsloth/Qwen3.5-2B", "unsloth/Qwen3.5-4B", "unsloth/Qwen3.5-27B", "unsloth/Qwen3.5-35B-A3B", ), "zephyr": ( "unsloth/zephyr-sft-bnb-4bit", "unsloth/zephyr-sft", "HuggingFaceH4/mistral-7b-sft-beta", ), 
"chatml": ( "unsloth/yi-6b-bnb-4bit", "unsloth/yi-6b", "01-ai/Yi-6B", "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit", "unsloth/Hermes-2-Pro-Mistral-7B", "NousResearch/Hermes-2-Pro-Mistral-7B", "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit", "unsloth/OpenHermes-2.5-Mistral-7B", "teknium/OpenHermes-2.5-Mistral-7B", ), "gpt-oss": ( "unsloth/gpt-oss-20b-unsloth-bnb-4bit", "unsloth/gpt-oss-20b", "openai/gpt-oss-20b", "unsloth/gpt-oss-20b-unsloth-bnb-4bit", "unsloth/gpt-oss-120b-unsloth-bnb-4bit", "unsloth/gpt-oss-120b", "openai/gpt-oss-120b", "unsloth/gpt-oss-120b-unsloth-bnb-4bit", ), "starling": ( "unsloth/Starling-LM-7B-beta-bnb-4bit", "unsloth/Starling-LM-7B-beta", "Nexusflow/Starling-LM-7B-beta", ), "yi-chat": ( "unsloth/yi-34b-chat-bnb-4bit", "01-ai/Yi-6B-Chat", "01-ai/Yi-34B-Chat", ), "glm": ( "unsloth/GLM-4.7-Flash-unsloth-bnb-4bit", "unsloth/GLM-4.7-Flash", "THUDM/GLM-4.7-Flash", "unsloth/GLM-4.7-Flash-bnb-4bit", ), } MODEL_TO_TEMPLATE_MAPPER = {} for key, values in TEMPLATE_TO_MODEL_MAPPER.items(): for value in values: MODEL_TO_TEMPLATE_MAPPER[value] = key # Get lowercased lowered_key = key.lower() for value in values: MODEL_TO_TEMPLATE_MAPPER[value.lower()] = lowered_key TEMPLATE_TO_RESPONSES_MAPPER = { "gemma-3": { "instruction": "user\n", "response": "model\n", }, "gemma3n": { "instruction": "user\n", "response": "model\n", }, "qwen3.5": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "qwen3-instruct": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "qwen3-thinking": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n\n", }, "qwen3": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "qwen2.5": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "llama-3.2": { "instruction": "<|start_header_id|>user<|end_header_id|>\n\n", "response": "<|start_header_id|>assistant<|end_header_id|>\n\n", }, "llama-3.3": { 
"instruction": "<|start_header_id|>user<|end_header_id|>\n\n", "response": "<|start_header_id|>assistant<|end_header_id|>\n\n", }, "llama-3.1": { "instruction": "<|start_header_id|>user<|end_header_id|>\n\n", "response": "<|start_header_id|>assistant<|end_header_id|>\n\n", }, "llama3": { "instruction": "<|start_header_id|>user<|end_header_id|>\n\n", "response": "<|start_header_id|>assistant<|end_header_id|>\n\n", }, "phi-3": { "instruction": "<|user|>\n", "response": "<|assistant|>\n", }, "phi-3.5": { "instruction": "<|user|>\n", "response": "<|assistant|>\n", }, "phi-4": { "instruction": "<|im_start|>user<|im_sep|>", "response": "<|im_start|>assistant<|im_sep|>", }, "mistral": { "instruction": "[INST] ", "response": " [/INST]", }, "llama": { "instruction": "[INST] ", "response": " [/INST]", }, "chatml": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "zephyr": { "instruction": "<|user|>\n", "response": "<|assistant|>\n", }, "unsloth": { "instruction": ">>> User: ", "response": ">>> Assistant: ", }, "vicuna": { "instruction": "USER: ", "response": "ASSISTANT: ", }, "alpaca": { "instruction": "### Instruction:\n", "response": "### Response:\n", }, "gemma": { "instruction": "user\n", "response": "model\n", }, "gemma2": { "instruction": "user\n", "response": "model\n", }, "gpt-oss": { "instruction": "<|start|>user<|message|>", "response": "<|start|>assistant<|channel|>final<|message|>", }, "lfm-2": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "starling": { "instruction": "GPT4 Correct User: ", "response": "GPT4 Correct Assistant: ", }, "yi-chat": { "instruction": "<|im_start|>user\n", "response": "<|im_start|>assistant\n", }, "glm": { "instruction": "[gMASK]<|user|>", "response": "<|assistant|>", }, } ================================================ FILE: studio/backend/utils/datasets/vlm_processing.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # 
Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ VLM (Vision-Language Model) processing utilities. This module contains functions for generating smart instructions for VLM datasets based on content analysis and heuristics. """ import re from itertools import islice def generate_smart_vlm_instruction( dataset, text_column = "text", image_column = "image", dataset_name = None, ): """ Generate smart, context-aware instruction for VLM datasets using heuristics. Strategy: 1. Check for explicit question/instruction columns → use that 2. Infer from text column name + sample content 3. Analyze dataset name for task hints 4. Fall back to generic instruction Returns: dict: { "instruction": str or None, # None means use column content "instruction_type": "explicit" | "inferred" | "generic", "uses_dynamic_instruction": bool, # True if instruction varies per sample "confidence": float, # 0.0 to 1.0 } """ column_names = set(next(iter(dataset)).keys()) sample = next(iter(dataset)) # ===== LEVEL 1: Explicit Instruction Columns ===== # Check for columns that contain per-sample instructions question_columns = ["question", "query", "prompt", "instruction", "user_prompt"] for col in question_columns: if col in column_names: # Check if this column has varied content (not just empty/same) sample_content = sample[col] if sample_content and str(sample_content).strip(): return { "instruction": None, # Signal to use column content "instruction_column": col, "instruction_type": "explicit", "uses_dynamic_instruction": True, "confidence": 1.0, } # ===== LEVEL 2: Infer from Column Names + Content ===== text_col_lower = text_column.lower() # Sample the text content to detect patterns text_sample = str(sample.get(text_column, ""))[:500] # First 500 chars # Task-specific keywords and their instructions task_patterns = { # OCR / Transcription "ocr": { "keywords": ["ocr", "transcribe", "transcript"], "content_hints": [ 
r"[A-Za-z\u0600-\u06FF]{10,}" ], # Long text passages (Latin/Arabic) "instruction": "Transcribe all the text shown in this image.", "confidence": 0.9, }, # LaTeX / Math "latex": { "keywords": ["latex", "math", "formula", "equation"], "content_hints": [r"\\[a-z]+\{", r"\^", r"_", r"\\frac"], # LaTeX commands "instruction": "Convert this image to LaTeX notation.", "confidence": 0.95, }, # Caption / Description "caption": { "keywords": ["caption", "description", "describe"], "content_hints": [], "instruction": "Provide a detailed description of this image.", "confidence": 0.85, }, # Medical / Radiology "medical": { "keywords": [ "medical", "radiology", "xray", "ct", "mri", "scan", "diagnosis", ], "content_hints": [r"\b(lesion|radiograph|patient|diagnosis|findings)\b"], "instruction": "Analyze this medical image and describe the key findings.", "confidence": 0.9, }, # Code / Programming "code": { "keywords": ["code", "program", "function", "algorithm"], "content_hints": [r"def |class |function|import |return "], "instruction": "Explain what this code visualization shows.", "confidence": 0.85, }, # Chart / Graph "chart": { "keywords": ["chart", "graph", "plot", "visualization", "diagram"], "content_hints": [r"\b(axis|legend|bar|line|pie|scatter)\b"], "instruction": "Describe this chart or graph, including key data points and trends.", "confidence": 0.85, }, # Document / Text Recognition "document": { "keywords": ["document", "page", "paragraph", "article"], "content_hints": [r"\n.*\n.*\n"], # Multi-line text "instruction": "Extract and transcribe the text from this document image.", "confidence": 0.85, }, } # Check column name matches best_match = None best_score = 0.0 for task_name, task_info in task_patterns.items(): score = 0.0 # Check column name if any(keyword in text_col_lower for keyword in task_info["keywords"]): score += 0.5 # Check dataset name if provided if dataset_name and any( keyword in dataset_name.lower() for keyword in task_info["keywords"] ): score += 
0.3 # Check content patterns for pattern in task_info["content_hints"]: if re.search(pattern, text_sample, re.IGNORECASE): score += 0.4 break if score > best_score: best_score = score best_match = task_info if best_match and best_score > 0.5: # Confidence threshold return { "instruction": best_match["instruction"], "instruction_column": None, "instruction_type": "inferred", "uses_dynamic_instruction": False, "confidence": min(best_score, best_match["confidence"]), } # ===== LEVEL 3: Analyze Dataset Name ===== if dataset_name: name_lower = dataset_name.lower() # Common dataset name patterns if "vqa" in name_lower or "question" in name_lower: return { "instruction": "Answer the question about this image.", "instruction_column": None, "instruction_type": "inferred", "uses_dynamic_instruction": False, "confidence": 0.75, } if "coco" in name_lower or "flickr" in name_lower: return { "instruction": "Provide a detailed caption for this image.", "instruction_column": None, "instruction_type": "inferred", "uses_dynamic_instruction": False, "confidence": 0.75, } # ===== LEVEL 4: LLM-Assisted Instruction Generation ===== try: from .llm_assist import llm_generate_vlm_instruction sample_rows = [] for s in islice(dataset, 5): row = {} for col in s: val = s[col] if hasattr(val, "size") and hasattr(val, "mode"): # PIL Image row[col] = "" elif isinstance(val, list): row[col] = str(val)[:300] else: row[col] = str(val)[:300] sample_rows.append(row) llm_result = llm_generate_vlm_instruction( column_names = list(column_names), samples = sample_rows, dataset_name = dataset_name, ) if llm_result and llm_result.get("instruction"): print( f"\n[DEBUG] LLM-assisted VLM instruction generated: " f"'{llm_result['instruction']}' (confidence={llm_result.get('confidence', 'N/A')})\n", flush = True, ) return { "instruction": llm_result["instruction"], "instruction_column": None, "instruction_type": "llm_assisted", "uses_dynamic_instruction": False, "confidence": llm_result.get("confidence", 0.85), 
} except Exception as e: import logging logging.getLogger(__name__).debug(f"LLM-assisted instruction skipped: {e}") # ===== LEVEL 5: Generic Fallback ===== return { "instruction": "Describe this image in detail.", "instruction_column": None, "instruction_type": "generic", "uses_dynamic_instruction": False, "confidence": 0.5, } ================================================ FILE: studio/backend/utils/hardware/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Hardware detection and GPU utilities """ from .hardware import ( DeviceType, DEVICE, CHAT_ONLY, detect_hardware, get_device, is_apple_silicon, clear_gpu_cache, get_gpu_memory_info, log_gpu_memory, get_gpu_summary, get_package_versions, get_gpu_utilization, get_physical_gpu_count, get_visible_gpu_count, safe_num_proc, ) __all__ = [ "DeviceType", "DEVICE", "CHAT_ONLY", "detect_hardware", "get_device", "is_apple_silicon", "clear_gpu_cache", "get_gpu_memory_info", "log_gpu_memory", "get_gpu_summary", "get_package_versions", "get_gpu_utilization", "get_physical_gpu_count", "get_visible_gpu_count", "safe_num_proc", ] ================================================ FILE: studio/backend/utils/hardware/hardware.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Hardware detection — run once at startup, read everywhere. Usage: # At FastAPI lifespan startup: from utils.hardware import detect_hardware detect_hardware() # Anywhere else: from utils.hardware import DEVICE, DeviceType, is_apple_silicon if DEVICE == DeviceType.CUDA: import torch ... 
""" import platform import structlog from loggers import get_logger from enum import Enum from typing import Optional, Dict, Any logger = get_logger(__name__) # ========== Device Enum ========== class DeviceType(str, Enum): """Supported compute backends. Inherits from str so it serializes cleanly in JSON.""" CUDA = "cuda" MLX = "mlx" CPU = "cpu" # ========== Global State (set once by detect_hardware) ========== DEVICE: Optional[DeviceType] = None CHAT_ONLY: bool = True # No CUDA GPU -> GGUF chat only (Mac, CPU-only, etc.) # ========== Detection ========== def is_apple_silicon() -> bool: """Check if running on Apple Silicon hardware (pure platform check, no ML imports).""" return platform.system() == "Darwin" and platform.machine() == "arm64" def _has_torch() -> bool: """Check if PyTorch is importable.""" try: import torch return True except ImportError: return False def _has_mlx() -> bool: """Check if MLX is importable.""" try: import mlx.core return True except ImportError: return False def detect_hardware() -> DeviceType: """ Detect the best available compute device and set the module-level DEVICE global. Should be called exactly once during FastAPI lifespan startup. Safe to call multiple times (idempotent). Detection order: 1. CUDA (NVIDIA GPU, requires torch) 2. MLX (Apple Silicon via MLX framework) 3. 
CPU (fallback) """ global DEVICE, CHAT_ONLY CHAT_ONLY = True # reset -- only CUDA sets it to False # --- CUDA: try PyTorch --- if _has_torch(): import torch if torch.cuda.is_available(): DEVICE = DeviceType.CUDA CHAT_ONLY = False device_name = torch.cuda.get_device_properties(0).name print(f"Hardware detected: CUDA — {device_name}") return DEVICE # --- MLX: Apple Silicon --- if is_apple_silicon() and _has_mlx(): DEVICE = DeviceType.MLX chip = platform.processor() or platform.machine() print(f"Hardware detected: MLX — Apple Silicon ({chip})") return DEVICE # --- Fallback --- DEVICE = DeviceType.CPU print("Hardware detected: CPU (no GPU backend available)") return DEVICE # ========== Convenience helpers ========== def get_device() -> DeviceType: """ Return the detected device. Auto-detects if detect_hardware() hasn't been called yet. Prefer calling detect_hardware() explicitly at startup instead. """ global DEVICE if DEVICE is None: detect_hardware() return DEVICE def clear_gpu_cache(): """ Clear GPU memory cache for the current device. Safe to call on any platform — no-ops gracefully. """ import gc gc.collect() device = get_device() if device == DeviceType.CUDA: import torch torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() elif device == DeviceType.MLX: # MLX manages memory automatically; no explicit cache clear needed. # mlx.core has no empty_cache equivalent — gc.collect() above is enough. pass def get_gpu_memory_info() -> Dict[str, Any]: """ Get GPU memory information. Supports CUDA (NVIDIA), MLX (Apple Silicon), and CPU-only environments. 
""" device = get_device() # ---- CUDA path ---- if device == DeviceType.CUDA: try: import torch idx = torch.cuda.current_device() props = torch.cuda.get_device_properties(idx) total = props.total_memory allocated = torch.cuda.memory_allocated(idx) reserved = torch.cuda.memory_reserved(idx) return { "available": True, "backend": device.value, "device": idx, "device_name": props.name, "total_gb": total / (1024**3), "allocated_gb": allocated / (1024**3), "reserved_gb": reserved / (1024**3), "free_gb": (total - allocated) / (1024**3), "utilization_pct": (allocated / total) * 100, } except Exception as e: logger.error(f"Error getting CUDA GPU info: {e}") return {"available": False, "backend": device.value, "error": str(e)} # ---- MLX path (Apple Silicon) ---- if device == DeviceType.MLX: try: import mlx.core as mx import psutil # MLX uses unified memory — report system memory as the pool total = psutil.virtual_memory().total # MLX doesn't expose per-process GPU allocation; report 0 as allocated allocated = 0 return { "available": True, "backend": device.value, "device": 0, "device_name": f"Apple Silicon ({platform.processor() or platform.machine()})", "total_gb": total / (1024**3), "allocated_gb": allocated / (1024**3), "reserved_gb": 0, "free_gb": (total - allocated) / (1024**3), "utilization_pct": (allocated / total) * 100 if total else 0, } except Exception as e: logger.error(f"Error getting MLX GPU info: {e}") return {"available": False, "backend": device.value, "error": str(e)} # ---- CPU-only ---- return {"available": False, "backend": "cpu"} def log_gpu_memory(context: str): """Log GPU memory usage with context.""" memory_info = get_gpu_memory_info() if memory_info.get("available"): backend = memory_info.get("backend", "unknown").upper() device_name = memory_info.get("device_name", "") label = f"{backend}" + (f" ({device_name})" if device_name else "") logger.info( f"GPU Memory [{context}] {label}: " 
f"{memory_info['allocated_gb']:.2f}GB/{memory_info['total_gb']:.2f}GB " f"({memory_info['utilization_pct']:.1f}% used, " f"{memory_info['free_gb']:.2f}GB free)" ) else: logger.info(f"GPU Memory [{context}]: No GPU available (CPU-only)") # ========== GPU Summary & Package Versions ========== def get_gpu_summary() -> Dict[str, Any]: """ Return a compact summary of the primary GPU. Returns dict with keys: gpu_name – e.g. "NVIDIA L4" (or None) vram_total_gb – e.g. 22.17 (or None) """ mem = get_gpu_memory_info() if mem.get("available"): return { "gpu_name": mem.get("device_name"), "vram_total_gb": round(mem.get("total_gb", 0), 2), "vram_free_gb": round(mem.get("free_gb", 0), 2), } return {"gpu_name": None, "vram_total_gb": None, "vram_free_gb": None} def get_package_versions() -> Dict[str, Optional[str]]: """ Return the installed versions of key ML packages. Uses importlib.metadata (stdlib) so no subprocess is needed. CUDA version comes from torch.version.cuda. Returns dict with keys: unsloth, torch, transformers, cuda. Missing packages yield None. """ from importlib.metadata import version as pkg_version, PackageNotFoundError packages = ("unsloth", "torch", "transformers") versions: Dict[str, Optional[str]] = {} for name in packages: try: versions[name] = pkg_version(name) except PackageNotFoundError: versions[name] = None # CUDA toolkit version bundled with torch try: import torch versions["cuda"] = getattr(torch.version, "cuda", None) except Exception: versions["cuda"] = None return versions # ========== Live GPU Utilization (nvidia-smi) ========== def get_gpu_utilization() -> Dict[str, Any]: """ Return a live snapshot of GPU utilization via ``nvidia-smi``. Designed to be polled by the frontend during training (not streaming). Uses ``nvidia-smi --query-gpu`` which is the most accurate source for utilization %, temperature, and power draw – stats that PyTorch does not expose. 
Returns dict with keys: available – bool, whether stats could be retrieved gpu_utilization_pct – GPU core utilization % temperature_c – GPU temperature in °C vram_used_gb – VRAM currently used (GiB) vram_total_gb – VRAM total (GiB) vram_utilization_pct – VRAM used / total * 100 power_draw_w – current power draw (W) power_limit_w – power limit (W) power_utilization_pct – power draw / limit * 100 """ device = get_device() if device != DeviceType.CUDA: return {"available": False, "backend": device.value} def _parse_smi_value(raw: str): """Parse a single nvidia-smi CSV value. Returns float or None for [N/A].""" raw = raw.strip() if not raw or raw == "[N/A]": return None try: return float(raw) except (ValueError, TypeError): return None # ── nvidia-smi (most complete source) ─────────────────────── smi_data = {} try: import subprocess result = subprocess.run( [ "nvidia-smi", "--query-gpu=utilization.gpu,temperature.gpu," "memory.used,memory.total,power.draw,power.limit", "--format=csv,noheader,nounits", ], capture_output = True, text = True, timeout = 5, ) if result.returncode == 0 and result.stdout.strip(): # nvidia-smi outputs one line per GPU; take GPU 0 first_line = result.stdout.strip().splitlines()[0] parts = [p.strip() for p in first_line.split(",")] if len(parts) >= 6: smi_data = { "gpu_util": _parse_smi_value(parts[0]), "temp": _parse_smi_value(parts[1]), "vram_used_mb": _parse_smi_value(parts[2]), "vram_total_mb": _parse_smi_value(parts[3]), "power_draw": _parse_smi_value(parts[4]), "power_limit": _parse_smi_value(parts[5]), } except FileNotFoundError: logger.debug("nvidia-smi not found, falling back to torch.cuda") except Exception as e: logger.warning(f"nvidia-smi query failed: {e}") # ── Backfill VRAM from torch.cuda if nvidia-smi returned [N/A] ── vram_used_mb = smi_data.get("vram_used_mb") vram_total_mb = smi_data.get("vram_total_mb") if vram_used_mb is None or vram_total_mb is None: try: import torch idx = torch.cuda.current_device() props = 
torch.cuda.get_device_properties(idx) if vram_total_mb is None: vram_total_mb = props.total_memory / (1024**2) # bytes → MiB if vram_used_mb is None: vram_used_mb = torch.cuda.memory_allocated(idx) / (1024**2) except Exception as e: logger.debug(f"torch.cuda VRAM backfill failed: {e}") # ── Build response ────────────────────────────────────────── gpu_util = smi_data.get("gpu_util") temp = smi_data.get("temp") power_draw = smi_data.get("power_draw") power_limit = smi_data.get("power_limit") vram_used_gb = round(vram_used_mb / 1024, 2) if vram_used_mb is not None else None vram_total_gb = ( round(vram_total_mb / 1024, 2) if vram_total_mb is not None else None ) vram_pct = ( round((vram_used_mb / vram_total_mb) * 100, 1) if vram_used_mb is not None and vram_total_mb and vram_total_mb > 0 else None ) power_pct = ( round((power_draw / power_limit) * 100, 1) if power_draw is not None and power_limit and power_limit > 0 else None ) # If we got at least something useful, report available has_any = any(v is not None for v in [gpu_util, temp, vram_used_gb, power_draw]) if not has_any: return {"available": False, "backend": device.value} return { "available": True, "backend": device.value, "gpu_utilization_pct": gpu_util, "temperature_c": temp, "vram_used_gb": vram_used_gb, "vram_total_gb": vram_total_gb, "vram_utilization_pct": vram_pct, "power_draw_w": power_draw, "power_limit_w": power_limit, "power_utilization_pct": power_pct, } # ========== Multi-GPU Detection & Safe num_proc ========== _physical_gpu_count: Optional[int] = None _visible_gpu_count: Optional[int] = None def get_physical_gpu_count() -> int: """ Return the number of physical NVIDIA GPUs on the machine. Uses ``nvidia-smi -L`` which is NOT affected by CUDA_VISIBLE_DEVICES, so it always reflects the true hardware count. Result is cached after the first call. 
""" global _physical_gpu_count if _physical_gpu_count is not None: return _physical_gpu_count try: import subprocess result = subprocess.run( ["nvidia-smi", "-L"], capture_output = True, text = True, timeout = 5, ) if result.returncode == 0 and result.stdout.strip(): _physical_gpu_count = len(result.stdout.strip().splitlines()) else: _physical_gpu_count = 1 except Exception: _physical_gpu_count = 1 return _physical_gpu_count def get_visible_gpu_count() -> int: """ Return the number of GPUs visible to this process. Respects ``CUDA_VISIBLE_DEVICES`` -- if set, only those GPUs count. Falls back to physical count if the env var is unset or torch is unavailable. Result is cached after the first call. """ global _visible_gpu_count if _visible_gpu_count is not None: return _visible_gpu_count import os cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES") if cuda_visible is not None: # "" means zero GPUs, "0" means 1, "0,1,2" means 3 cuda_visible = cuda_visible.strip() if cuda_visible == "" or cuda_visible == "-1": _visible_gpu_count = 0 else: _visible_gpu_count = len([x for x in cuda_visible.split(",") if x.strip()]) return _visible_gpu_count # CUDA_VISIBLE_DEVICES not set -- try torch, fall back to physical count try: import torch _visible_gpu_count = torch.cuda.device_count() except Exception: _visible_gpu_count = get_physical_gpu_count() return _visible_gpu_count def safe_num_proc(desired: Optional[int] = None) -> int: """ Return a safe ``num_proc`` for ``dataset.map()`` calls. On Windows, always returns 1 because Python uses ``spawn`` instead of ``fork`` for multiprocessing -- the overhead of re-importing torch, transformers, unsloth etc. per worker is typically slower than single-process for normal dataset sizes. On multi-GPU machines (where multiple GPUs are *visible* to this process) the NVIDIA driver spawns extra background threads, making ``os.fork()`` prone to deadlocks when many workers are created. This helper caps ``num_proc`` to 4 on such machines. 
When ``CUDA_VISIBLE_DEVICES`` restricts to a single GPU, the cap does not apply. Args: desired: The num_proc you *want*. If None, auto-computes from ``os.cpu_count()``. Returns: A safe integer ≥ 1. """ import os import sys # Windows uses 'spawn' for multiprocessing -- the overhead of re-importing # torch/transformers/unsloth per worker is typically slower than single-process. if sys.platform == "win32": return 1 if desired is None or not isinstance(desired, int): desired = max(1, os.cpu_count() // 3) visible = get_visible_gpu_count() if visible > 1: capped = min(4, desired) logger.info( f"Multi-GPU detected ({visible} visible GPUs) -- " f"capping num_proc {desired} -> {capped} to avoid fork deadlocks" ) return capped return desired ================================================ FILE: studio/backend/utils/inference/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference utility functions """ from utils.inference.inference_config import load_inference_config __all__ = ["load_inference_config"] ================================================ FILE: studio/backend/utils/inference/inference_config.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Inference configuration loading utilities. This module provides functions to load inference parameters (temperature, top_p, top_k, min_p) from model YAML configuration files, with fallback to default.yaml. Includes family-based lookup from inference_defaults.json for GGUF models. 
""" from pathlib import Path from typing import Dict, Any, Optional import json import yaml import structlog from loggers import get_logger from utils.models.model_config import load_model_defaults logger = get_logger(__name__) # ── Family-based inference defaults (loaded once, cached) ────────────── _FAMILY_DEFAULTS: Optional[Dict[str, Any]] = None _FAMILY_PATTERNS: Optional[list] = None def _load_family_defaults(): """Load and cache inference_defaults.json.""" global _FAMILY_DEFAULTS, _FAMILY_PATTERNS if _FAMILY_DEFAULTS is not None: return json_path = ( Path(__file__).parent.parent.parent / "assets" / "configs" / "inference_defaults.json" ) try: with open(json_path, "r", encoding = "utf-8") as f: data = json.load(f) _FAMILY_DEFAULTS = data.get("families", {}) _FAMILY_PATTERNS = data.get("patterns", []) except Exception as e: logger.warning(f"Failed to load inference_defaults.json: {e}") _FAMILY_DEFAULTS = {} _FAMILY_PATTERNS = [] def get_family_inference_params(model_id: str) -> Dict[str, Any]: """ Look up recommended inference parameters by model family. Extracts the model family from the identifier (e.g. "unsloth/Qwen3.5-9B-GGUF" -> "qwen3.5") and returns the matching parameters from inference_defaults.json. Args: model_id: Model identifier (e.g. "unsloth/Qwen3.5-9B-GGUF") Returns: Dict with inference params, or empty dict if no family match. 
""" _load_family_defaults() if not _FAMILY_PATTERNS or not _FAMILY_DEFAULTS: return {} # Normalize: lowercase, strip org prefix normalized = model_id.lower() if "/" in normalized: normalized = normalized.split("/", 1)[1] # Match against patterns (ordered longest-match-first in the JSON) for pattern in _FAMILY_PATTERNS: if pattern in normalized: params = _FAMILY_DEFAULTS.get(pattern, {}) if params: return dict(params) return {} def _has_specific_yaml(model_identifier: str) -> bool: """Check if a model has its own YAML config (not just default.yaml).""" from utils.models.model_config import _REVERSE_MODEL_MAPPING script_dir = Path(__file__).parent.parent.parent defaults_dir = script_dir / "assets" / "configs" / "model_defaults" # Check the mapping if model_identifier.lower() in _REVERSE_MODEL_MAPPING: return True # Check for exact filename match model_filename = model_identifier.replace("/", "_") + ".yaml" for config_path in defaults_dir.rglob(model_filename): if config_path.is_file(): return True return False def load_inference_config(model_identifier: str) -> Dict[str, Any]: """ Load inference configuration parameters for a model. Priority chain: 1. Model-specific YAML (if it exists and has inference params) 2. Family-based defaults from inference_defaults.json 3. 
default.yaml fallback Args: model_identifier: Model identifier (e.g., "unsloth/llama-3-8b-bnb-4bit") Returns: Dictionary containing inference parameters: { "temperature": float, "top_p": float, "top_k": int, "min_p": float } """ # Load model defaults to get inference parameters model_defaults = load_model_defaults(model_identifier) # Load default.yaml for fallback values script_dir = Path(__file__).parent.parent.parent defaults_dir = script_dir / "assets" / "configs" / "model_defaults" default_config_path = defaults_dir / "default.yaml" default_inference = {} if default_config_path.exists(): try: with open(default_config_path, "r", encoding = "utf-8") as f: default_config = yaml.safe_load(f) or {} default_inference = default_config.get("inference", {}) except Exception as e: logger.warning(f"Failed to load default.yaml: {e}") # Family-based defaults from inference_defaults.json family_params = get_family_inference_params(model_identifier) model_inference = model_defaults.get("inference", {}) # If the model has its own YAML config, those values take priority over family defaults. # If it only fell back to default.yaml, family defaults take priority. 
has_own_yaml = _has_specific_yaml(model_identifier) def _get_param(key, hardcoded_default): if has_own_yaml: # Model-specific YAML wins, then family fills gaps, then default.yaml val = model_inference.get(key) if val is not None and isinstance(val, (int, float)): return val if key in family_params: return family_params[key] return default_inference.get(key, hardcoded_default) else: # No model-specific YAML: family wins, then default.yaml if key in family_params: return family_params[key] return default_inference.get(key, hardcoded_default) inference_config = { "temperature": _get_param("temperature", 0.7), "top_p": _get_param("top_p", 0.95), "top_k": _get_param("top_k", -1), "min_p": _get_param("min_p", 0.01), "presence_penalty": _get_param("presence_penalty", 0.0), "trust_remote_code": model_inference.get( "trust_remote_code", default_inference.get("trust_remote_code", False) ), } return inference_config ================================================ FILE: studio/backend/utils/models/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Model and LoRA configuration handling """ from .model_config import ( ModelConfig, GgufVariantInfo, is_vision_model, is_embedding_model, detect_audio_type, is_audio_input_type, VALID_AUDIO_TYPES, scan_trained_loras, scan_exported_models, load_model_defaults, get_base_model_from_lora, load_model_config, list_gguf_variants, MODEL_NAME_MAPPING, UI_STATUS_INDICATORS, ) from .checkpoints import scan_checkpoints __all__ = [ "ModelConfig", "GgufVariantInfo", "is_vision_model", "is_embedding_model", "detect_audio_type", "is_audio_input_type", "VALID_AUDIO_TYPES", "scan_trained_loras", "scan_exported_models", "load_model_defaults", "get_base_model_from_lora", "load_model_config", "list_gguf_variants", "MODEL_NAME_MAPPING", "UI_STATUS_INDICATORS", "scan_checkpoints", ] ================================================ FILE: studio/backend/utils/models/checkpoints.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Checkpoint scanning utilities for discovering training runs and their checkpoints. """ import json import structlog from loggers import get_logger from pathlib import Path from typing import List, Optional, Tuple from utils.paths import outputs_root, resolve_output_dir logger = get_logger(__name__) def _read_checkpoint_loss(checkpoint_path: Path) -> Optional[float]: """ Read the training loss from a checkpoint's trainer_state.json. Returns the loss from the last log_history entry, or None if unavailable. 
""" trainer_state = checkpoint_path / "trainer_state.json" if not trainer_state.exists(): return None try: with open(trainer_state) as f: state = json.load(f) log_history = state.get("log_history", []) if log_history: return log_history[-1].get("loss") except Exception as e: logger.debug(f"Could not read loss from {trainer_state}: {e}") return None def scan_checkpoints( outputs_dir: str = str(outputs_root()), ) -> List[Tuple[str, List[Tuple[str, str, Optional[float]]], dict]]: """ Scan outputs folder for training runs and their checkpoints. Returns: List of tuples: [(model_name, [(display_name, checkpoint_path, loss), ...], metadata), ...] metadata keys: base_model, peft_type, lora_rank (all optional) The first entry in each checkpoint list is the main adapter; its loss is set to the loss of the last (highest-step) intermediate checkpoint. """ models = [] outputs_path = resolve_output_dir(outputs_dir) if not outputs_path.exists(): logger.warning(f"Outputs directory not found: {outputs_dir}") return models try: for item in outputs_path.iterdir(): if not item.is_dir(): continue config_file = item / "config.json" adapter_config = item / "adapter_config.json" if not (config_file.exists() or adapter_config.exists()): continue # Extract training metadata from adapter_config.json / config.json metadata: dict = {} try: if adapter_config.exists(): cfg = json.loads(adapter_config.read_text()) metadata["base_model"] = cfg.get("base_model_name_or_path") metadata["peft_type"] = cfg.get("peft_type") metadata["lora_rank"] = cfg.get("r") elif config_file.exists(): cfg = json.loads(config_file.read_text()) metadata["base_model"] = cfg.get("_name_or_path") except Exception: pass # Fallback: extract base model name from folder name # e.g. 
"unsloth_Llama-3.2-3B-Instruct_1771227800" → "unsloth/Llama-3.2-3B-Instruct" if not metadata.get("base_model"): parts = item.name.rsplit("_", 1) if len(parts) == 2 and parts[1].isdigit(): name_part = parts[0] idx = name_part.find("_") if idx > 0: metadata["base_model"] = ( name_part[:idx] + "/" + name_part[idx + 1 :] ) else: metadata["base_model"] = name_part # This is a valid training run checkpoints = [] # Placeholder for the main adapter — loss filled from last checkpoint below checkpoints.append((item.name, str(item), None)) # Scan for intermediate checkpoints (checkpoint-N subdirs) for sub in sorted(item.iterdir()): if not sub.is_dir() or not sub.name.startswith("checkpoint-"): continue sub_config = sub / "config.json" sub_adapter = sub / "adapter_config.json" if sub_config.exists() or sub_adapter.exists(): loss = _read_checkpoint_loss(sub) checkpoints.append((sub.name, str(sub), loss)) # Assign the last checkpoint's loss to the main adapter entry if len(checkpoints) > 1: last_checkpoint_loss = checkpoints[-1][2] checkpoints[0] = ( checkpoints[0][0], checkpoints[0][1], last_checkpoint_loss, ) models.append((item.name, checkpoints, metadata)) logger.debug( f"Found model: {item.name} with {len(checkpoints)} checkpoint(s)" ) # Sort by modification time (newest first) models.sort(key = lambda x: Path(x[1][0][1]).stat().st_mtime, reverse = True) logger.info(f"Found {len(models)} training runs in {outputs_dir}") return models except Exception as e: logger.error(f"Error scanning checkpoints: {e}") return [] ================================================ FILE: studio/backend/utils/models/model_config.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
# See /studio/LICENSE.AGPL-3.0
"""
Model and LoRA configuration handling
"""

from transformers import AutoConfig
from dataclasses import dataclass
from typing import Optional, Dict, Any
from utils.paths import (
    normalize_path,
    is_local_path,
    is_model_cached,
    outputs_root,
    exports_root,
    resolve_output_dir,
    resolve_export_dir,
)
from utils.utils import without_hf_auth
import structlog
from loggers import get_logger
import os
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple
import json
import yaml

logger = get_logger(__name__)

# Alias table for per-model YAML defaults.
# Each key is a YAML filename; its value lists every Hugging Face model id
# that should share that YAML. The table is inverted below into a
# lowercase-keyed lookup (_REVERSE_MODEL_MAPPING).
MODEL_NAME_MAPPING = {
    # ── Embedding models ──
    "unsloth_all-MiniLM-L6-v2.yaml": [
        "unsloth/all-MiniLM-L6-v2",
        "sentence-transformers/all-MiniLM-L6-v2",
    ],
    "unsloth_bge-m3.yaml": [
        "unsloth/bge-m3",
        "BAAI/bge-m3",
    ],
    "unsloth_embeddinggemma-300m.yaml": [
        "unsloth/embeddinggemma-300m",
        "google/embeddinggemma-300m",
    ],
    "unsloth_gte-modernbert-base.yaml": [
        "unsloth/gte-modernbert-base",
        "Alibaba-NLP/gte-modernbert-base",
    ],
    "unsloth_Qwen3-Embedding-0.6B.yaml": [
        "unsloth/Qwen3-Embedding-0.6B",
        "Qwen/Qwen3-Embedding-0.6B",
        "unsloth/Qwen3-Embedding-4B",
        "Qwen/Qwen3-Embedding-4B",
    ],
    # ── Other models ──
    "unsloth_answerdotai_ModernBERT-large.yaml": [
        "answerdotai/ModernBERT-large",
    ],
    "unsloth_Qwen2.5-Coder-7B-Instruct-bnb-4bit.yaml": [
        "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit",
        "unsloth/Qwen2.5-Coder-7B-Instruct",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
    ],
    "unsloth_codegemma-7b-bnb-4bit.yaml": [
        "unsloth/codegemma-7b-bnb-4bit",
        "unsloth/codegemma-7b",
        "google/codegemma-7b",
    ],
    "unsloth_ERNIE-4.5-21B-A3B-PT.yaml": [
        "unsloth/ERNIE-4.5-21B-A3B-PT",
    ],
    "unsloth_ERNIE-4.5-VL-28B-A3B-PT.yaml": [
        "unsloth/ERNIE-4.5-VL-28B-A3B-PT",
    ],
    "tiiuae_Falcon-H1-0.5B-Instruct.yaml": [
        "tiiuae/Falcon-H1-0.5B-Instruct",
        "unsloth/Falcon-H1-0.5B-Instruct",
    ],
    # NOTE(review): the -unsloth-bnb-4bit id is listed twice; one of them was
    # probably meant to be another variant — confirm.
    "unsloth_functiongemma-270m-it.yaml": [
        "unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
        "google/functiongemma-270m-it",
        "unsloth/functiongemma-270m-it-unsloth-bnb-4bit",
    ],
    "unsloth_gemma-2-2b.yaml": [
        "unsloth/gemma-2-2b-bnb-4bit",
        "google/gemma-2-2b",
    ],
    "unsloth_gemma-2-27b-bnb-4bit.yaml": [
        "unsloth/gemma-2-9b-bnb-4bit",
        "unsloth/gemma-2-9b",
        "google/gemma-2-9b",
        "unsloth/gemma-2-27b",
        "google/gemma-2-27b",
    ],
    "unsloth_gemma-3-4b-pt.yaml": [
        "unsloth/gemma-3-4b-pt-unsloth-bnb-4bit",
        "google/gemma-3-4b-pt",
        "unsloth/gemma-3-4b-pt-bnb-4bit",
    ],
    "unsloth_gemma-3-4b-it.yaml": [
        "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
        "google/gemma-3-4b-it",
        "unsloth/gemma-3-4b-it-bnb-4bit",
    ],
    "unsloth_gemma-3-27b-it.yaml": [
        "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
        "google/gemma-3-27b-it",
        "unsloth/gemma-3-27b-it-bnb-4bit",
    ],
    "unsloth_gemma-3-270m-it.yaml": [
        "unsloth/gemma-3-270m-it-unsloth-bnb-4bit",
        "google/gemma-3-270m-it",
        "unsloth/gemma-3-270m-it-bnb-4bit",
    ],
    # NOTE(review): the -unsloth-bnb-4bit id is listed twice here as well — confirm.
    "unsloth_gemma-3n-E4B-it.yaml": [
        "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
        "google/gemma-3n-E4B-it",
        "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    ],
    "unsloth_gemma-3n-E4B.yaml": [
        "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
        "google/gemma-3n-E4B",
    ],
    "unsloth_gpt-oss-20b.yaml": [
        "openai/gpt-oss-20b",
        "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
        "unsloth/gpt-oss-20b-BF16",
    ],
    "unsloth_gpt-oss-120b.yaml": [
        "openai/gpt-oss-120b",
        "unsloth/gpt-oss-120b-unsloth-bnb-4bit",
    ],
    "unsloth_granite-4.0-350m-unsloth-bnb-4bit.yaml": [
        "unsloth/granite-4.0-350m",
        "ibm-granite/granite-4.0-350m",
        "unsloth/granite-4.0-350m-bnb-4bit",
    ],
    "unsloth_granite-4.0-h-micro.yaml": [
        "ibm-granite/granite-4.0-h-micro",
        "unsloth/granite-4.0-h-micro-bnb-4bit",
        "unsloth/granite-4.0-h-micro-unsloth-bnb-4bit",
    ],
    "unsloth_LFM2-1.2B.yaml": [
        "unsloth/LFM2-1.2B",
    ],
    "unsloth_llama-3-8b-bnb-4bit.yaml": [
        "unsloth/llama-3-8b",
        "meta-llama/Meta-Llama-3-8B",
    ],
    "unsloth_llama-3-8b-Instruct-bnb-4bit.yaml": [
        "unsloth/llama-3-8b-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
    ],
    "unsloth_Meta-Llama-3.1-70B-bnb-4bit.yaml": [
        "unsloth/Meta-Llama-3.1-8B-bnb-4bit",
        "unsloth/Meta-Llama-3.1-8B-unsloth-bnb-4bit",
        "meta-llama/Meta-Llama-3.1-8B",
        "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
        "unsloth/Meta-Llama-3.1-8B",
        "unsloth/Meta-Llama-3.1-70B",
        "meta-llama/Meta-Llama-3.1-70B",
        "unsloth/Meta-Llama-3.1-405B-bnb-4bit",
        "meta-llama/Meta-Llama-3.1-405B",
    ],
    "unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit.yaml": [
        "unsloth/Meta-Llama-3.1-8B-Instruct-unsloth-bnb-4bit",
        "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "unsloth/Meta-Llama-3.1-8B-Instruct",
        "RedHatAI/Llama-3.1-8B-Instruct-FP8",
        "unsloth/Llama-3.1-8B-Instruct-FP8-Block",
        "unsloth/Llama-3.1-8B-Instruct-FP8-Dynamic",
    ],
    "unsloth_Llama-3.2-3B-Instruct.yaml": [
        "unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
        "meta-llama/Llama-3.2-3B-Instruct",
        "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
        "RedHatAI/Llama-3.2-3B-Instruct-FP8",
        "unsloth/Llama-3.2-3B-Instruct-FP8-Block",
        "unsloth/Llama-3.2-3B-Instruct-FP8-Dynamic",
    ],
    "unsloth_Llama-3.2-1B-Instruct.yaml": [
        "unsloth/Llama-3.2-1B-Instruct-unsloth-bnb-4bit",
        "meta-llama/Llama-3.2-1B-Instruct",
        "unsloth/Llama-3.2-1B-Instruct-bnb-4bit",
        "RedHatAI/Llama-3.2-1B-Instruct-FP8",
        "unsloth/Llama-3.2-1B-Instruct-FP8-Block",
        "unsloth/Llama-3.2-1B-Instruct-FP8-Dynamic",
    ],
    "unsloth_Llama-3.2-11B-Vision-Instruct.yaml": [
        "unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit",
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
    ],
    "unsloth_Llama-3.3-70B-Instruct.yaml": [
        "unsloth/Llama-3.3-70B-Instruct-unsloth-bnb-4bit",
        "meta-llama/Llama-3.3-70B-Instruct",
        "unsloth/Llama-3.3-70B-Instruct-bnb-4bit",
        "RedHatAI/Llama-3.3-70B-Instruct-FP8",
        "unsloth/Llama-3.3-70B-Instruct-FP8-Block",
        "unsloth/Llama-3.3-70B-Instruct-FP8-Dynamic",
    ],
    "unsloth_Llasa-3B.yaml": [
        "HKUSTAudio/Llasa-1B",
        "unsloth/Llasa-3B",
    ],
    "unsloth_Magistral-Small-2509-unsloth-bnb-4bit.yaml": [
        "unsloth/Magistral-Small-2509",
        "mistralai/Magistral-Small-2509",
        "unsloth/Magistral-Small-2509-bnb-4bit",
    ],
    "unsloth_Ministral-3-3B-Instruct-2512.yaml": [
        "unsloth/Ministral-3-3B-Instruct-2512",
    ],
    "unsloth_mistral-7b-v0.3-bnb-4bit.yaml": [
        "unsloth/mistral-7b-v0.3-bnb-4bit",
        "unsloth/mistral-7b-v0.3",
        "mistralai/Mistral-7B-v0.3",
    ],
    "unsloth_Mistral-Nemo-Base-2407-bnb-4bit.yaml": [
        "unsloth/Mistral-Nemo-Base-2407-bnb-4bit",
        "unsloth/Mistral-Nemo-Base-2407",
        "mistralai/Mistral-Nemo-Base-2407",
        "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
        "unsloth/Mistral-Nemo-Instruct-2407",
        "mistralai/Mistral-Nemo-Instruct-2407",
    ],
    "unsloth_Mistral-Small-Instruct-2409.yaml": [
        "unsloth/Mistral-Small-Instruct-2409-bnb-4bit",
        "mistralai/Mistral-Small-Instruct-2409",
    ],
    "unsloth_mistral-7b-instruct-v0.3-bnb-4bit.yaml": [
        "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
        "unsloth/mistral-7b-instruct-v0.3",
        "mistralai/Mistral-7B-Instruct-v0.3",
    ],
    "unsloth_Qwen2.5-1.5B-Instruct.yaml": [
        "unsloth/Qwen2.5-1.5B-Instruct-unsloth-bnb-4bit",
        "Qwen/Qwen2.5-1.5B-Instruct",
        "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
    ],
    "unsloth_Nemotron-3-Nano-30B-A3B.yaml": [
        "unsloth/Nemotron-3-Nano-30B-A3B",
    ],
    "unsloth_orpheus-3b-0.1-ft.yaml": [
        "unsloth/orpheus-3b-0.1-ft",
        "unsloth/orpheus-3b-0.1-ft-unsloth-bnb-4bit",
        "canopylabs/orpheus-3b-0.1-ft",
        "unsloth/orpheus-3b-0.1-ft-bnb-4bit",
    ],
    "OuteAI_Llama-OuteTTS-1.0-1B.yaml": [
        "OuteAI/Llama-OuteTTS-1.0-1B",
        "unsloth/Llama-OuteTTS-1.0-1B",
        "unsloth/llama-outetts-1.0-1b",
        "OuteAI/OuteTTS-1.0-0.6B",
        "unsloth/OuteTTS-1.0-0.6B",
        "unsloth/outetts-1.0-0.6b",
    ],
    "unsloth_PaddleOCR-VL.yaml": [
        "unsloth/PaddleOCR-VL",
    ],
    "unsloth_Phi-3-medium-4k-instruct.yaml": [
        "unsloth/Phi-3-medium-4k-instruct-bnb-4bit",
        "microsoft/Phi-3-medium-4k-instruct",
    ],
    "unsloth_Phi-3.5-mini-instruct.yaml": [
        "unsloth/Phi-3.5-mini-instruct-bnb-4bit",
        "microsoft/Phi-3.5-mini-instruct",
    ],
    "unsloth_Phi-4.yaml": [
        "unsloth/phi-4-unsloth-bnb-4bit",
        "microsoft/phi-4",
        "unsloth/phi-4-bnb-4bit",
    ],
    "unsloth_Pixtral-12B-2409.yaml": [
        "unsloth/Pixtral-12B-2409-unsloth-bnb-4bit",
        "mistralai/Pixtral-12B-2409",
        "unsloth/Pixtral-12B-2409-bnb-4bit",
    ],
    "unsloth_Qwen2-7B.yaml": [
        "unsloth/Qwen2-7B-bnb-4bit",
        "Qwen/Qwen2-7B",
    ],
    "unsloth_Qwen2-VL-7B-Instruct.yaml": [
        "unsloth/Qwen2-VL-7B-Instruct-unsloth-bnb-4bit",
        "Qwen/Qwen2-VL-7B-Instruct",
        "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit",
    ],
    "unsloth_Qwen2.5-7B.yaml": [
        "unsloth/Qwen2.5-7B-unsloth-bnb-4bit",
        "Qwen/Qwen2.5-7B",
        "unsloth/Qwen2.5-7B-bnb-4bit",
    ],
    "unsloth_Qwen2.5-Coder-1.5B-Instruct.yaml": [
        "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit",
        "Qwen/Qwen2.5-Coder-1.5B-Instruct",
    ],
    "unsloth_Qwen2.5-Coder-14B-Instruct.yaml": [
        "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit",
        "Qwen/Qwen2.5-Coder-14B-Instruct",
    ],
    "unsloth_Qwen2.5-VL-7B-Instruct-bnb-4bit.yaml": [
        "unsloth/Qwen2.5-VL-7B-Instruct",
        "Qwen/Qwen2.5-VL-7B-Instruct",
        "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
    ],
    "unsloth_Qwen3-0.6B.yaml": [
        "unsloth/Qwen3-0.6B-unsloth-bnb-4bit",
        "Qwen/Qwen3-0.6B",
        "unsloth/Qwen3-0.6B-bnb-4bit",
        "Qwen/Qwen3-0.6B-FP8",
        "unsloth/Qwen3-0.6B-FP8",
    ],
    "unsloth_Qwen3-4B-Instruct-2507.yaml": [
        "unsloth/Qwen3-4B-Instruct-2507-unsloth-bnb-4bit",
        "Qwen/Qwen3-4B-Instruct-2507",
        "unsloth/Qwen3-4B-Instruct-2507-bnb-4bit",
        "Qwen/Qwen3-4B-Instruct-2507-FP8",
        "unsloth/Qwen3-4B-Instruct-2507-FP8",
    ],
    "unsloth_Qwen3-4B-Thinking-2507.yaml": [
        "unsloth/Qwen3-4B-Thinking-2507-unsloth-bnb-4bit",
        "Qwen/Qwen3-4B-Thinking-2507",
        "unsloth/Qwen3-4B-Thinking-2507-bnb-4bit",
        "Qwen/Qwen3-4B-Thinking-2507-FP8",
        "unsloth/Qwen3-4B-Thinking-2507-FP8",
    ],
    "unsloth_Qwen3-14B-Base-unsloth-bnb-4bit.yaml": [
        "unsloth/Qwen3-14B-Base",
        "Qwen/Qwen3-14B-Base",
        "unsloth/Qwen3-14B-Base-bnb-4bit",
    ],
    "unsloth_Qwen3-14B.yaml": [
        "unsloth/Qwen3-14B-unsloth-bnb-4bit",
        "Qwen/Qwen3-14B",
        "unsloth/Qwen3-14B-bnb-4bit",
        "Qwen/Qwen3-14B-FP8",
        "unsloth/Qwen3-14B-FP8",
    ],
    "unsloth_Qwen3-32B.yaml": [
        "unsloth/Qwen3-32B-unsloth-bnb-4bit",
        "Qwen/Qwen3-32B",
        "unsloth/Qwen3-32B-bnb-4bit",
        "Qwen/Qwen3-32B-FP8",
        "unsloth/Qwen3-32B-FP8",
    ],
    "unsloth_Qwen3-VL-8B-Instruct-unsloth-bnb-4bit.yaml": [
        "Qwen/Qwen3-VL-8B-Instruct-FP8",
        "unsloth/Qwen3-VL-8B-Instruct-FP8",
        "unsloth/Qwen3-VL-8B-Instruct",
        "Qwen/Qwen3-VL-8B-Instruct",
        "unsloth/Qwen3-VL-8B-Instruct-bnb-4bit",
    ],
    "sesame_csm-1b.yaml": [
        "sesame/csm-1b",
        "unsloth/csm-1b",
    ],
    "Spark-TTS-0.5B_LLM.yaml": [
        "Spark-TTS-0.5B/LLM",
        "unsloth/Spark-TTS-0.5B",
    ],
    "unsloth_tinyllama-bnb-4bit.yaml": [
        "unsloth/tinyllama",
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
    ],
    "unsloth_whisper-large-v3.yaml": [
        "unsloth/whisper-large-v3",
        "openai/whisper-large-v3",
    ],
}

# Inverted index: lowercased model id -> YAML filename. If an id appeared
# under two YAMLs, the later entry in MODEL_NAME_MAPPING would win.
_REVERSE_MODEL_MAPPING = {}
for canonical_file, model_names in MODEL_NAME_MAPPING.items():
    for model_name in model_names:
        _REVERSE_MODEL_MAPPING[model_name.lower()] = canonical_file


def load_model_config(
    model_name: str,
    use_auth: bool = False,
    token: Optional[str] = None,
    trust_remote_code: bool = True,
):
    """
    Fetch a model's ``AutoConfig``, choosing how (or whether) to authenticate.

    Precedence:
      * a non-empty ``token`` is always forwarded to ``from_pretrained``;
      * otherwise ``use_auth=True`` lets the loader use cached credentials;
      * otherwise the load runs inside ``without_hf_auth()`` with
        ``token=None`` (intended for public-model checks).
    """
    if token:
        return AutoConfig.from_pretrained(
            model_name, trust_remote_code = trust_remote_code, token = token
        )

    if use_auth:
        # No explicit token: defer to whatever credentials are cached.
        return AutoConfig.from_pretrained(
            model_name,
            trust_remote_code = trust_remote_code,
        )

    # Anonymous access: strip ambient HF auth for the duration of the call.
    with without_hf_auth():
        return AutoConfig.from_pretrained(
            model_name,
            trust_remote_code = trust_remote_code,
            token = None,
        )


# VLM architecture suffixes and known VLM model_type values.
_VLM_ARCH_SUFFIXES = ("ForConditionalGeneration", "ForVisionText2Text")
_VLM_MODEL_TYPES = {
    "phi3_v",
    "llava",
    "llava_next",
    "llava_onevision",
    "internvl_chat",
    "cogvlm2",
    "minicpmv",
}

# Audio-only model types whose architectures also end in
# "ForConditionalGeneration" (e.g. CsmForConditionalGeneration,
# WhisperForConditionalGeneration); they must never be reported as VLMs.
# _VISION_CHECK_SCRIPT hard-codes the same set — keep the two in sync.
_AUDIO_ONLY_MODEL_TYPES = {"csm", "whisper"}

# Pre-computed .venv_t5 path and backend dir for subprocess version switching.
_VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5")
_BACKEND_DIR = str(Path(__file__).resolve().parent.parent.parent)

# Environment variable carrying the optional HF token into the vision-check
# subprocess. Passing it via the environment rather than argv keeps the token
# off the command line, which other local users can read (e.g. with `ps`).
_VISION_CHECK_TOKEN_ENV = "UNSLOTH_VISION_CHECK_HF_TOKEN"

# Inline script executed in a subprocess with transformers 5.x activated.
# argv: venv_t5 dir, backend dir, model_name. The token (if any) arrives via
# _VISION_CHECK_TOKEN_ENV. The result is printed as one JSON object on the
# last line of stdout.
_VISION_CHECK_SCRIPT = r"""
import sys, os, json
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Activate transformers 5.x
venv_t5 = sys.argv[1]
backend_dir = sys.argv[2]
model_name = sys.argv[3]
token = os.environ.get("UNSLOTH_VISION_CHECK_HF_TOKEN") or None

sys.path.insert(0, venv_t5)
if backend_dir not in sys.path:
    sys.path.insert(0, backend_dir)

try:
    from transformers import AutoConfig
    kwargs = {"trust_remote_code": True}
    if token:
        kwargs["token"] = token
    config = AutoConfig.from_pretrained(model_name, **kwargs)

    model_type = getattr(config, "model_type", "unknown")
    # `architectures` may exist but be None; normalise before iterating.
    archs = list(getattr(config, "architectures", None) or [])

    # Audio-only models share the ForConditionalGeneration suffix.
    # Keep in sync with _AUDIO_ONLY_MODEL_TYPES / _VLM_* in model_config.py.
    if model_type in {"csm", "whisper"}:
        is_vlm = False
    else:
        is_vlm = (
            any(x.endswith(("ForConditionalGeneration", "ForVisionText2Text")) for x in archs)
            or hasattr(config, "vision_config")
            or hasattr(config, "img_processor")
            or hasattr(config, "image_token_index")
            or model_type in {"phi3_v", "llava", "llava_next", "llava_onevision",
                              "internvl_chat", "cogvlm2", "minicpmv"}
        )

    print(json.dumps({"is_vision": is_vlm, "model_type": model_type, "architectures": archs}))
except Exception as exc:
    print(json.dumps({"error": str(exc)}))
    sys.exit(1)
"""


def _is_vision_model_subprocess(
    model_name: str, hf_token: Optional[str] = None
) -> bool:
    """Run the is_vision_model check in a subprocess with transformers 5.x.

    Same pattern as the training/inference workers: spawn a clean interpreter
    with .venv_t5/ prepended to sys.path so AutoConfig recognizes newer
    architectures (glm4_moe_lite, etc.).

    Args:
        model_name: Model identifier (HF repo or local path).
        hf_token: Optional HF token, handed to the child through its environment.

    Returns:
        The child's verdict; False if the child fails, times out, or prints
        output that is not valid JSON.
    """
    child_env = dict(os.environ)
    if hf_token:
        child_env[_VISION_CHECK_TOKEN_ENV] = hf_token
    else:
        # Don't let a value inherited from our own environment leak through.
        child_env.pop(_VISION_CHECK_TOKEN_ENV, None)

    try:
        result = subprocess.run(
            [
                sys.executable,
                "-c",
                _VISION_CHECK_SCRIPT,
                _VENV_T5_DIR,
                _BACKEND_DIR,
                model_name,
            ],
            capture_output = True,
            text = True,
            timeout = 60,
            env = child_env,
        )
        if result.returncode != 0:
            stderr = result.stderr.strip()
            logger.warning(
                "Vision check subprocess failed for '%s': %s",
                model_name,
                stderr or result.stdout.strip(),
            )
            return False

        # Remote code or libraries may print to stdout before the result, so
        # only the last non-empty line is treated as the JSON payload.
        output_lines = [line for line in result.stdout.splitlines() if line.strip()]
        data = json.loads(output_lines[-1] if output_lines else "")
        if "error" in data:
            logger.warning(
                "Vision check subprocess error for '%s': %s",
                model_name,
                data["error"],
            )
            return False

        is_vlm = data["is_vision"]
        logger.info(
            "Vision check (subprocess, transformers 5.x) for '%s': "
            "model_type=%s, architectures=%s, is_vision=%s",
            model_name,
            data.get("model_type"),
            data.get("architectures"),
            is_vlm,
        )
        return is_vlm

    except subprocess.TimeoutExpired:
        logger.warning("Vision check subprocess timed out for '%s'", model_name)
        return False
    except Exception as exc:
        logger.warning("Vision check subprocess failed for '%s': %s", model_name, exc)
        return False


def is_vision_model(model_name: str, hf_token: Optional[str] = None) -> bool:
    """
    Detect vision-language models (VLMs) by inspecting the model config.

    Works for fine-tuned models since they inherit the base architecture.
    For models that require transformers 5.x (e.g. GLM-4.7-Flash), the check
    runs in a subprocess with .venv_t5/ activated — same pattern as the
    training and inference workers.

    Args:
        model_name: Model identifier (HF repo or local path)
        hf_token: Optional HF token for accessing gated/private models

    Returns:
        True if any VLM signal is found; False otherwise, for audio-only
        models, or when the config cannot be loaded.
    """
    # Models that need transformers 5.x must be checked in a subprocess
    # because AutoConfig in the main process (transformers 4.57.x) doesn't
    # recognize their architectures.
    from utils.transformers_version import needs_transformers_5

    if needs_transformers_5(model_name):
        logger.info(
            "Model '%s' needs transformers 5.x — checking vision via subprocess",
            model_name,
        )
        return _is_vision_model_subprocess(model_name, hf_token = hf_token)

    try:
        config = load_model_config(model_name, use_auth = True, token = hf_token)

        # Exclude audio-only models that share the ForConditionalGeneration suffix.
        model_type = getattr(config, "model_type", None)
        if model_type in _AUDIO_ONLY_MODEL_TYPES:
            return False

        # Check 1: Architecture class name patterns. `architectures` may be
        # present but None; iterating None would raise and skip checks 2-5.
        architectures = getattr(config, "architectures", None) or []
        if any(x.endswith(_VLM_ARCH_SUFFIXES) for x in architectures):
            logger.info(
                f"Model {model_name} detected as VLM: architecture {architectures}"
            )
            return True

        # Check 2: Has vision_config (most VLMs: LLaVA, Gemma-3, Qwen2-VL, etc.)
        if hasattr(config, "vision_config"):
            logger.info(f"Model {model_name} detected as VLM: has vision_config")
            return True

        # Check 3: Has img_processor (Phi-3.5 Vision uses this instead of vision_config)
        if hasattr(config, "img_processor"):
            logger.info(f"Model {model_name} detected as VLM: has img_processor")
            return True

        # Check 4: Has image_token_index (common in VLMs for image placeholder tokens)
        if hasattr(config, "image_token_index"):
            logger.info(f"Model {model_name} detected as VLM: has image_token_index")
            return True

        # Check 5: Known VLM model_type values that may not match above checks
        if model_type in _VLM_MODEL_TYPES:
            logger.info(f"Model {model_name} detected as VLM: model_type={model_type}")
            return True

        return False

    except Exception as e:
        logger.warning(f"Could not determine if {model_name} is vision model: {e}")
        return False


VALID_AUDIO_TYPES = ("snac", "csm", "bicodec", "dac", "whisper", "audio_vlm")

# Cache detection results per session to avoid repeated API calls
_audio_detection_cache: Dict[str, Optional[str]] = {}

# Tokenizer token patterns → audio_type (all 6 types detected from tokenizer_config.json).
# Checked in insertion order; the first matching type wins.
_AUDIO_TOKEN_PATTERNS = {
    "csm": lambda tokens: "<|AUDIO|>" in tokens and "<|audio_eos|>" in tokens,
    "whisper": lambda tokens: "<|startoftranscript|>" in tokens,
    # Gemma 3n audio placeholder token.
    "audio_vlm": lambda tokens: "<audio_soft_token>" in tokens,
    "bicodec": lambda tokens: any(t.startswith("<|bicodec_") for t in tokens),
    "dac": lambda tokens: "<|audio_start|>" in tokens
    and "<|audio_end|>" in tokens
    and "<|text_start|>" in tokens
    and "<|text_end|>" in tokens,
    # SNAC-based TTS models add tens of thousands of <custom_token_N> entries.
    "snac": lambda tokens: sum(1 for t in tokens if t.startswith("<custom_token_"))
    > 10000,
}

# Tokenizer config locations checked in each candidate directory / repo.
_TOKENIZER_CONFIG_PATHS = ("tokenizer_config.json", "LLM/tokenizer_config.json")


def detect_audio_type(model_name: str, hf_token: Optional[str] = None) -> Optional[str]:
    """
    Dynamically detect if a model is an audio model and return its type.

    Fully dynamic — works for any model, not just known ones.
    Uses tokenizer_config.json special tokens to detect all 6 audio types.
    Results are cached per model_name for the lifetime of the process.

    Returns: audio_type string ('snac', 'csm', 'bicodec', 'dac', 'whisper', 'audio_vlm') or None.
    """
    if model_name in _audio_detection_cache:
        return _audio_detection_cache[model_name]

    result = _detect_audio_from_tokenizer(model_name, hf_token)
    _audio_detection_cache[model_name] = result
    if result:
        logger.info(f"Model {model_name} detected as audio model: audio_type={result}")
    return result


def _detect_audio_from_tokenizer(
    model_name: str, hf_token: Optional[str] = None
) -> Optional[str]:
    """Detect audio type from tokenizer special tokens (for LLM-based audio models).

    Lookup order:
      1. ``model_name`` as a local directory (e.g. a fine-tuned run on disk);
         local directories never fall through to the network.
      2. Snapshots in the local HF hub cache (works for gated/offline models).
      3. Files fetched from huggingface.co.
    Each location is searched for the paths in _TOKENIZER_CONFIG_PATHS, and
    its ``added_tokens_decoder`` is matched against _AUDIO_TOKEN_PATTERNS.
    """

    def _check_token_patterns(tok_config: dict) -> Optional[str]:
        added = tok_config.get("added_tokens_decoder", {})
        if not added:
            return None
        token_contents = [v.get("content", "") for v in added.values()]
        for audio_type, check_fn in _AUDIO_TOKEN_PATTERNS.items():
            if check_fn(token_contents):
                return audio_type
        return None

    def _check_directory(directory: Path) -> Optional[str]:
        # First tokenizer config under `directory` that matches a pattern wins.
        for tok_path in _TOKENIZER_CONFIG_PATHS:
            tok_file = directory / tok_path
            if tok_file.exists():
                result = _check_token_patterns(json.loads(tok_file.read_text()))
                if result:
                    return result
        return None

    # 1) Local model directory
    try:
        local_dir = Path(model_name)
        if local_dir.is_dir():
            return _check_directory(local_dir)
    except Exception as e:
        logger.debug(f"Could not read local tokenizer config for {model_name}: {e}")
        return None

    # 2) Check local HF cache (works for gated/offline models)
    try:
        from huggingface_hub.constants import HF_HUB_CACHE

        repo_dir = Path(HF_HUB_CACHE) / f"models--{model_name.replace('/', '--')}"
        snapshots_dir = repo_dir / "snapshots"
        if snapshots_dir.exists():
            for snapshot in snapshots_dir.iterdir():
                result = _check_directory(snapshot)
                if result:
                    return result
    except Exception as e:
        logger.debug(f"Could not check local cache for {model_name}: {e}")

    # 3) Fall back to HuggingFace API
    try:
        import requests

        # Use provided token, or fall back to env
        token = hf_token or os.environ.get("HF_TOKEN")
        headers = {"Authorization": f"Bearer {token}"} if token else {}

        for tok_path in _TOKENIZER_CONFIG_PATHS:
            url = f"https://huggingface.co/{model_name}/resolve/main/{tok_path}"
            resp = requests.get(url, headers = headers, timeout = 15)
            if not resp.ok:
                continue
            result = _check_token_patterns(resp.json())
            if result:
                return result
        return None
    except Exception as e:
        logger.debug(
            f"Could not detect audio type from tokenizer for {model_name}: {e}"
        )
        return None


def is_audio_input_type(audio_type: Optional[str]) -> bool:
    """Check if an audio_type accepts audio input (ASR/speech understanding).

    Whisper (ASR) and audio_vlm (Gemma3n) accept audio input.
    """
    return audio_type in ("whisper", "audio_vlm")


def _is_mmproj(filename: str) -> bool:
    """Check if a GGUF filename is a vision projection (mmproj) file."""
    return "mmproj" in filename.lower()


def detect_mmproj_file(path: str) -> Optional[str]:
    """
    Find the mmproj (vision projection) GGUF file in a directory.

    Args:
        path: Directory to search — or a .gguf file (uses its parent dir).

    Returns:
        Full path to the mmproj .gguf file, or None if not found. If several
        mmproj files exist, the alphabetically first is returned so the result
        does not depend on filesystem listing order.
    """
    p = Path(path)
    search_dir = p.parent if p.is_file() else p
    if not search_dir.is_dir():
        return None
    for f in sorted(search_dir.glob("*.gguf")):
        if _is_mmproj(f.name):
            return str(f.resolve())
    return None


def detect_gguf_model(path: str) -> Optional[str]:
    """
    Check if the given local path is or contains a GGUF model file.

    Handles two cases:
    1. path is a direct .gguf file path
    2. path is a directory containing .gguf files (the largest one is chosen)

    Skips mmproj (vision projection) files — those must be passed via
    ``--mmproj``, not ``-m``. Use :func:`detect_mmproj_file` instead.

    Returns the full path to the .gguf file if found, None otherwise.
    For HuggingFace repo detection, use detect_gguf_model_remote() instead.
    """
    p = Path(path)

    # Case 1: direct .gguf file
    if p.suffix == ".gguf" and p.is_file():
        if _is_mmproj(p.name):
            return None
        return str(p.resolve())

    # Case 2: directory containing .gguf files (skip mmproj)
    if p.is_dir():
        gguf_files = sorted(
            (f for f in p.glob("*.gguf") if not _is_mmproj(f.name)),
            key = lambda f: f.stat().st_size,
            reverse = True,
        )
        if gguf_files:
            return str(gguf_files[0].resolve())

    return None


# Preferred GGUF quantization levels, in descending priority.
# Q4_K_M is a good default: small, fast, acceptable quality.
# UD (Unsloth Dynamic) variants are always preferred over standard quants
# because they provide better quality per bit. If the repo has no UD variants
# (e.g., bartowski repos), the standard quants are used as fallback.
# Ordered by best size/quality tradeoff, not raw quality.
_GGUF_QUANT_PREFERENCE = [ # UD variants (best quality per bit) -- Q4 is the sweet spot "UD-Q4_K_XL", "UD-Q4_K_L", "UD-Q5_K_XL", "UD-Q3_K_XL", "UD-Q6_K_XL", "UD-Q6_K_S", "UD-Q8_K_XL", "UD-Q2_K_XL", "UD-IQ4_NL", "UD-IQ4_XS", "UD-IQ3_S", "UD-IQ3_XXS", "UD-IQ2_M", "UD-IQ2_XXS", "UD-IQ1_M", "UD-IQ1_S", # Standard quants (fallback for non-Unsloth repos) "Q4_K_M", "Q4_K_S", "Q5_K_M", "Q5_K_S", "Q6_K", "Q8_0", "Q3_K_M", "Q3_K_L", "Q3_K_S", "Q2_K", "Q2_K_L", "IQ4_NL", "IQ4_XS", "IQ3_M", "IQ3_XXS", "IQ2_M", "IQ1_M", "F16", "BF16", "F32", ] def _pick_best_gguf(filenames: list[str]) -> Optional[str]: """ Pick the best GGUF file from a list of filenames. Prefers quantization levels in _GGUF_QUANT_PREFERENCE order. Falls back to the first .gguf file found. """ gguf_files = [f for f in filenames if f.endswith(".gguf")] if not gguf_files: return None # Try preferred quantization levels for quant in _GGUF_QUANT_PREFERENCE: for f in gguf_files: if quant in f: return f # Fallback: first GGUF file return gguf_files[0] @dataclass class GgufVariantInfo: """A single GGUF quantization variant from a HuggingFace repo.""" filename: str # e.g., "gemma-3-4b-it-Q4_K_M.gguf" quant: str # e.g., "Q4_K_M" (extracted from filename) size_bytes: int # file size def _extract_quant_label(filename: str) -> str: """ Extract quantization label like Q4_K_M, IQ4_XS, BF16 from a GGUF filename. Examples: "gemma-3-4b-it-Q4_K_M.gguf" → "Q4_K_M" "model-IQ4_NL.gguf" → "IQ4_NL" "model-BF16.gguf" → "BF16" "model-UD-IQ1_S.gguf" → "UD-IQ1_S" "model-UD-TQ1_0.gguf" → "UD-TQ1_0" "MXFP4_MOE/model-MXFP4_MOE-0001.gguf"→ "MXFP4_MOE" """ import re # Use only the basename (rfilename may include directory) basename = filename.rsplit("/", 1)[-1] # Strip .gguf and any shard suffix (-00001-of-00010) stem = re.sub(r"-\d{3,}-of-\d{3,}", "", basename.rsplit(".", 1)[0]) # Match known quantization patterns match = re.search( r"(UD-)?" 
# Optional UD- prefix (Ultra Discrete) r"(MXFP[0-9]+(?:_[A-Z0-9]+)*" # MXFP variants: MXFP4, MXFP4_MOE r"|IQ[0-9]+_[A-Z]+(?:_[A-Z0-9]+)?" # IQ variants: IQ4_XS, IQ4_NL, IQ1_S r"|TQ[0-9]+_[0-9]+" # Ternary quant: TQ1_0, TQ2_0 r"|Q[0-9]+_K_[A-Z]+" # K-quant: Q4_K_M, Q3_K_S r"|Q[0-9]+_[0-9]+" # Standard: Q8_0, Q5_1 r"|Q[0-9]+_K" # Short K-quant: Q6_K r"|BF16|F16|F32)", # Full precision stem, re.IGNORECASE, ) if match: prefix = match.group(1) or "" return f"{prefix}{match.group(2)}" # Fallback: last segment after hyphen return stem.split("-")[-1] def list_gguf_variants( repo_id: str, hf_token: Optional[str] = None, ) -> tuple[list[GgufVariantInfo], bool]: """ List all GGUF quantization variants in a HuggingFace repo. Separates main model files from mmproj (vision projection) files. The presence of mmproj files indicates a vision-capable model. Returns: (variants, has_vision): list of non-mmproj GGUF variants + vision flag. """ from huggingface_hub import model_info as hf_model_info info = hf_model_info(repo_id, token = hf_token, files_metadata = True) variants: list[GgufVariantInfo] = [] has_vision = False quant_totals: dict[str, int] = {} # quant -> total bytes quant_first_file: dict[str, str] = {} # quant -> first filename (for display) for sibling in info.siblings: fname = sibling.rfilename if not fname.endswith(".gguf"): continue size = sibling.size or 0 # mmproj files are vision projection models, not main model files if "mmproj" in fname.lower(): has_vision = True continue quant = _extract_quant_label(fname) quant_totals[quant] = quant_totals.get(quant, 0) + size if quant not in quant_first_file: quant_first_file[quant] = fname for quant, total_size in quant_totals.items(): variants.append( GgufVariantInfo( filename = quant_first_file[quant], quant = quant, size_bytes = total_size, ) ) # Sort by size descending (largest = best quality first). # Recommended pinning and OOM demotion are handled client-side # where GPU VRAM info is available. 
variants.sort(key = lambda v: -v.size_bytes) return variants, has_vision def detect_gguf_model_remote( repo_id: str, hf_token: Optional[str] = None, ) -> Optional[str]: """ Check if a HuggingFace repo contains GGUF files. Returns the filename of the best GGUF file in the repo, or None. """ try: from huggingface_hub import model_info as hf_model_info info = hf_model_info(repo_id, token = hf_token) repo_files = [s.rfilename for s in info.siblings] return _pick_best_gguf(repo_files) except Exception as e: logger.debug(f"Could not check GGUF files for '{repo_id}': {e}") return None def download_gguf_file( repo_id: str, filename: str, hf_token: Optional[str] = None, ) -> str: """ Download a specific GGUF file from a HuggingFace repo. Returns the local path to the downloaded file. """ from huggingface_hub import hf_hub_download local_path = hf_hub_download( repo_id = repo_id, filename = filename, token = hf_token, ) return local_path # Cache embedding detection results per session to avoid repeated HF API calls _embedding_detection_cache: Dict[tuple, bool] = {} def is_embedding_model(model_name: str, hf_token: Optional[str] = None) -> bool: """ Detect embedding/sentence-transformer models using HuggingFace model metadata. Uses a belt-and-suspenders approach combining three signals: 1. "sentence-transformers" in model tags 2. "feature-extraction" in model tags 3. pipeline_tag is "sentence-similarity" or "feature-extraction" This catches all known embedding models including those like gte-modernbert whose library_name is "transformers" rather than "sentence-transformers". Args: model_name: Model identifier (HF repo or local path) hf_token: Optional HF token for accessing gated/private models Returns: True if the model is an embedding model, False otherwise. Defaults to False for local paths or on errors. 
""" cache_key = (model_name, hf_token) if cache_key in _embedding_detection_cache: return _embedding_detection_cache[cache_key] # Local paths: check for sentence-transformer marker file (modules.json) if is_local_path(model_name): local_dir = normalize_path(model_name) is_emb = os.path.isfile(os.path.join(local_dir, "modules.json")) _embedding_detection_cache[cache_key] = is_emb return is_emb try: from huggingface_hub import model_info as hf_model_info info = hf_model_info(model_name, token = hf_token) tags = set(info.tags or []) pipeline_tag = info.pipeline_tag or "" is_emb = ( "sentence-transformers" in tags or "feature-extraction" in tags or pipeline_tag in ("sentence-similarity", "feature-extraction") ) _embedding_detection_cache[cache_key] = is_emb if is_emb: logger.info( f"Model {model_name} detected as embedding model: " f"pipeline_tag={pipeline_tag}, " f"sentence-transformers in tags={('sentence-transformers' in tags)}, " f"feature-extraction in tags={('feature-extraction' in tags)}" ) return is_emb except Exception as e: logger.warning(f"Could not determine if {model_name} is embedding model: {e}") _embedding_detection_cache[cache_key] = False return False def scan_trained_loras(outputs_dir: str = str(outputs_root())) -> List[Tuple[str, str]]: """ Scan outputs folder for trained LoRA adapters. Returns: List of tuples: [(display_name, adapter_path), ...] 
Example: [ ("unsloth_Meta-Llama-3.1_...", "./outputs/unsloth_Meta-Llama-3.1_.../"), ("my_finetuned_model", "./outputs/my_finetuned_model/"), ] """ trained_loras = [] outputs_path = resolve_output_dir(outputs_dir) if not outputs_path.exists(): logger.warning(f"Outputs directory not found: {outputs_dir}") return trained_loras try: for item in outputs_path.iterdir(): if item.is_dir(): # Check if this directory contains a LoRA adapter adapter_config = item / "adapter_config.json" adapter_model = item / "adapter_model.safetensors" if adapter_config.exists() or adapter_model.exists(): display_name = item.name adapter_path = str(item) trained_loras.append((display_name, adapter_path)) logger.debug(f"Found trained LoRA: {display_name}") # Sort by modification time (newest first) trained_loras.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True) logger.info( f"Found {len(trained_loras)} trained LoRA adapters in {outputs_dir}" ) return trained_loras except Exception as e: logger.error(f"Error scanning outputs folder: {e}") return [] def scan_exported_models( exports_dir: str = str(exports_root()), ) -> List[Tuple[str, str, str, Optional[str]]]: """ Scan exports folder for exported models (merged, LoRA, GGUF). Supports two directory layouts: - Two-level: {run}/{checkpoint}/ (merged & LoRA exports) - Flat: {name}-finetune-gguf/ (GGUF exports) Returns: List of tuples: [(display_name, model_path, export_type, base_model), ...] export_type: "lora" | "merged" | "gguf" """ results = [] exports_path = resolve_export_dir(exports_dir) if not exports_path.exists(): return results try: for run_dir in exports_path.iterdir(): if not run_dir.is_dir(): continue # Check for flat GGUF export (e.g. 
exports/gemma-3-4b-it-finetune-gguf/) # Filter out mmproj (vision projection) files — they aren't loadable as main models gguf_files = [f for f in run_dir.glob("*.gguf") if not _is_mmproj(f.name)] if gguf_files: base_model = None export_meta = run_dir / "export_metadata.json" try: if export_meta.exists(): meta = json.loads(export_meta.read_text()) base_model = meta.get("base_model") except Exception: pass display_name = run_dir.name model_path = str(gguf_files[0]) # path to the .gguf file results.append((display_name, model_path, "gguf", base_model)) logger.debug(f"Found GGUF export: {display_name}") continue # Two-level: {run}/{checkpoint}/ for checkpoint_dir in run_dir.iterdir(): if not checkpoint_dir.is_dir(): continue adapter_config = checkpoint_dir / "adapter_config.json" config_file = checkpoint_dir / "config.json" has_weights = any(checkpoint_dir.glob("*.safetensors")) or any( checkpoint_dir.glob("*.bin") ) has_gguf = any(checkpoint_dir.glob("*.gguf")) base_model = None export_type = None if adapter_config.exists(): export_type = "lora" try: cfg = json.loads(adapter_config.read_text()) base_model = cfg.get("base_model_name_or_path") except Exception: pass elif config_file.exists() and has_weights: export_type = "merged" export_meta = checkpoint_dir / "export_metadata.json" try: if export_meta.exists(): meta = json.loads(export_meta.read_text()) base_model = meta.get("base_model") except Exception: pass elif has_gguf: export_type = "gguf" gguf_list = list(checkpoint_dir.glob("*.gguf")) # Check checkpoint_dir first, then fall back to parent run_dir # (export.py writes metadata to the top-level export directory) for meta_dir in (checkpoint_dir, run_dir): export_meta = meta_dir / "export_metadata.json" try: if export_meta.exists(): meta = json.loads(export_meta.read_text()) base_model = meta.get("base_model") if base_model: break except Exception: pass display_name = f"{run_dir.name} / {checkpoint_dir.name}" model_path = str(gguf_list[0]) if gguf_list else 
str(checkpoint_dir) results.append((display_name, model_path, export_type, base_model)) logger.debug(f"Found GGUF export: {display_name}") continue else: continue # Fallback: read base model from the original training run's # adapter_config.json in ./outputs/{run_name}/ if not base_model: outputs_adapter_cfg = ( resolve_output_dir(run_dir.name) / "adapter_config.json" ) try: if outputs_adapter_cfg.exists(): cfg = json.loads(outputs_adapter_cfg.read_text()) base_model = cfg.get("base_model_name_or_path") except Exception: pass display_name = f"{run_dir.name} / {checkpoint_dir.name}" model_path = str(checkpoint_dir) results.append((display_name, model_path, export_type, base_model)) logger.debug(f"Found exported model: {display_name} ({export_type})") results.sort(key = lambda x: Path(x[1]).stat().st_mtime, reverse = True) logger.info(f"Found {len(results)} exported models in {exports_dir}") return results except Exception as e: logger.error(f"Error scanning exports folder: {e}") return [] def get_base_model_from_lora(lora_path: str) -> Optional[str]: """ Read the base model name from a LoRA adapter's config. 
Args: lora_path: Path to the LoRA adapter directory Returns: Base model identifier (e.g., "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit") or None if not found Example: >>> get_base_model_from_lora("./outputs/unsloth_Meta-Llama-3.1_.../") "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" """ try: lora_path_obj = Path(lora_path) # Try adapter_config.json first adapter_config_path = lora_path_obj / "adapter_config.json" if adapter_config_path.exists(): with open(adapter_config_path, "r") as f: config = json.load(f) base_model = config.get("base_model_name_or_path") if base_model: logger.info( f"Detected base model from adapter_config.json: {base_model}" ) return base_model # Fallback: try training_args.bin (requires torch) training_args_path = lora_path_obj / "training_args.bin" if training_args_path.exists(): try: import torch training_args = torch.load(training_args_path) if hasattr(training_args, "model_name_or_path"): base_model = training_args.model_name_or_path logger.info( f"Detected base model from training_args.bin: {base_model}" ) return base_model except Exception as e: logger.warning(f"Could not load training_args.bin: {e}") # Last resort: parse from directory name # Format: unsloth_Meta-Llama-3.1-8B-Instruct-bnb-4bit_timestamp dir_name = lora_path_obj.name if dir_name.startswith("unsloth_"): # Remove timestamp suffix (usually _1234567890) parts = dir_name.split("_") # Reconstruct model name if len(parts) >= 2: model_parts = parts[1:-1] # Skip "unsloth" and timestamp base_model = "unsloth/" + "_".join(model_parts) logger.info(f"Detected base model from directory name: {base_model}") return base_model logger.warning(f"Could not detect base model for LoRA: {lora_path}") return None except Exception as e: logger.error(f"Error reading base model from LoRA config: {e}") return None # Status indicators that appear in UI dropdowns UI_STATUS_INDICATORS = [" (Ready)", " (Loading...)", " (Active)", "↓ "] def load_model_defaults(model_name: str) -> Dict[str, Any]: """ 
Load default training parameters for a model from YAML file. Args: model_name: Model identifier (e.g., "unsloth/Meta-Llama-3.1-8B-bnb-4bit") Returns: Dictionary with default parameters from YAML file, or empty dict if not found The function looks for a YAML file in configs/model_defaults/ (including subfolders) based on the model name or its aliases from MODEL_NAME_MAPPING. If no specific file exists, it falls back to default.yaml. """ try: # Get the script directory to locate configs script_dir = Path(__file__).parent.parent.parent defaults_dir = script_dir / "assets" / "configs" / "model_defaults" # First, check if model is in the mapping if model_name.lower() in _REVERSE_MODEL_MAPPING: canonical_file = _REVERSE_MODEL_MAPPING[model_name.lower()] # Search in subfolders and root for config_path in defaults_dir.rglob(canonical_file): if config_path.is_file(): with open(config_path, "r", encoding = "utf-8") as f: config = yaml.safe_load(f) or {} logger.info( f"Loaded model defaults from {config_path} (via mapping)" ) return config # If model_name is a local path (e.g. /home/.../Spark-TTS-0.5B/LLM from # adapter_config.json), try matching the last 1-2 path components against # the registry (e.g. "Spark-TTS-0.5B/LLM"). 
if model_name not in _REVERSE_MODEL_MAPPING and ( model_name.startswith("/") or model_name.startswith(".") ): parts = Path(model_name).parts for depth in [2, 1]: if len(parts) >= depth: suffix = "/".join(parts[-depth:]) if suffix in _REVERSE_MODEL_MAPPING: canonical_file = _REVERSE_MODEL_MAPPING[suffix] for config_path in defaults_dir.rglob(canonical_file): if config_path.is_file(): with open(config_path, "r", encoding = "utf-8") as f: config = yaml.safe_load(f) or {} logger.info( f"Loaded model defaults from {config_path} (via path suffix '{suffix}')" ) return config # Try exact model name match (for backward compatibility) model_filename = model_name.replace("/", "_") + ".yaml" # Search in subfolders and root for config_path in defaults_dir.rglob(model_filename): if config_path.is_file(): with open(config_path, "r", encoding = "utf-8") as f: config = yaml.safe_load(f) or {} logger.info(f"Loaded model defaults from {config_path}") return config # Fall back to default.yaml default_config_path = defaults_dir / "default.yaml" if default_config_path.exists(): with open(default_config_path, "r", encoding = "utf-8") as f: config = yaml.safe_load(f) or {} logger.info(f"Loaded default model defaults from {default_config_path}") return config logger.warning(f"No default config found for model {model_name}") return {} except Exception as e: logger.error(f"Error loading model defaults for {model_name}: {e}") return {} @dataclass class ModelConfig: """Configuration for a model to load""" identifier: str # Clean model identifier (org/name or path) display_name: str # Original UI display name path: str # Normalized filesystem path is_local: bool # Is this a local file vs HF model? is_cached: bool # Is this already in HF cache? is_vision: bool # Is this a vision model? is_lora: bool # Is this a lora adapter? is_gguf: bool = False # Is this a GGUF model? is_audio: bool = False # Is this a TTS audio model? 
audio_type: Optional[str] = ( None # Audio codec type: 'snac', 'csm', 'bicodec', 'dac' ) has_audio_input: bool = False # Accepts audio input (ASR/speech understanding) gguf_file: Optional[str] = None # Full path to the .gguf file (local mode) gguf_mmproj_file: Optional[str] = ( None # Full path to the mmproj .gguf file (vision projection) ) gguf_hf_repo: Optional[str] = ( None # HF repo ID for -hf mode (e.g. "unsloth/gemma-3-4b-it-GGUF") ) gguf_variant: Optional[str] = None # Quantization variant (e.g. "Q4_K_M") base_model: Optional[str] = None # Base model (for LoRAs) @classmethod def from_lora_path( cls, lora_path: str, hf_token: Optional[str] = None ) -> Optional["ModelConfig"]: """ Create ModelConfig from a local LoRA adapter path. Automatically detects the base model from adapter config. Args: lora_path: Path to LoRA adapter (e.g., "./outputs/unsloth_Meta-Llama-3.1_.../") hf_token: HF token for vision detection Returns: ModelConfig for the LoRA adapter """ try: lora_path_obj = Path(lora_path) if not lora_path_obj.exists(): logger.error(f"LoRA path does not exist: {lora_path}") return None # Get base model base_model = get_base_model_from_lora(lora_path) if not base_model: logger.error(f"Could not determine base model for LoRA: {lora_path}") return None # Check if base model is vision is_vision = is_vision_model(base_model, hf_token = hf_token) # Check if base model is audio audio_type = detect_audio_type(base_model, hf_token = hf_token) display_name = lora_path_obj.name identifier = lora_path # Use path as identifier for local LoRAs return cls( identifier = identifier, display_name = display_name, path = lora_path, is_local = True, is_cached = True, # Local LoRAs are always "cached" is_vision = is_vision, is_lora = True, is_audio = audio_type is not None and audio_type != "audio_vlm", audio_type = audio_type, has_audio_input = is_audio_input_type(audio_type), base_model = base_model, ) except Exception as e: logger.error(f"Error creating ModelConfig from LoRA 
path: {e}") return None @classmethod def from_identifier( cls, model_id: str, hf_token: Optional[str] = None, is_lora: bool = False, gguf_variant: Optional[str] = None, ) -> Optional["ModelConfig"]: """ Create ModelConfig from a clean model identifier. For FastAPI routes where the frontend sends sanitized model paths. No Gradio dropdown parsing - expects clean identifiers like: - "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" - "./outputs/my_lora_adapter" - "/absolute/path/to/model" Args: model_id: Clean model identifier (HF repo name or local path) hf_token: Optional HF token for vision detection on gated models is_lora: Whether this is a LoRA adapter gguf_variant: Optional GGUF quantization variant (e.g. "Q4_K_M"). For remote GGUF repos, specifies which quant to load via -hf. If None, auto-selects using _pick_best_gguf(). Returns: ModelConfig or None if configuration cannot be created """ if not model_id or not model_id.strip(): return None identifier = model_id.strip() is_local = is_local_path(identifier) path = normalize_path(identifier) if is_local else identifier # Add unsloth/ prefix for shorthand HF models if not is_local and "/" not in identifier: identifier = f"unsloth/{identifier}" path = identifier # Enforce lowercase for remote Hugging Face identifiers to prevent cache duplication # Hugging Face Hub APIs are case-insensitive remotely, but case-sensitive locally (repo_folder_name). 
if not is_local: identifier = identifier.lower() path = path.lower() # Auto-detect GGUF models (check before LoRA/vision detection) if is_local: gguf_file = detect_gguf_model(path) if gguf_file: display_name = Path(gguf_file).stem logger.info(f"Detected local GGUF model: {gguf_file}") # Detect vision: check if base model is vision, then look for mmproj mmproj_file = None gguf_is_vision = False gguf_dir = Path(gguf_file).parent # Determine if this is a vision model from export metadata base_is_vision = False meta_path = gguf_dir / "export_metadata.json" if meta_path.exists(): try: meta = json.loads(meta_path.read_text()) base = meta.get("base_model") if base and is_vision_model(base, hf_token = hf_token): base_is_vision = True logger.info(f"GGUF base model '{base}' is a vision model") except Exception as e: logger.debug(f"Could not read export metadata: {e}") # If vision (or mmproj happens to exist), find the mmproj file mmproj_file = detect_mmproj_file(gguf_file) if mmproj_file: gguf_is_vision = True logger.info(f"Detected mmproj for vision: {mmproj_file}") elif base_is_vision: logger.warning( f"Base model is vision but no mmproj file found in {gguf_dir}" ) return cls( identifier = identifier, display_name = display_name, path = path, is_local = True, is_cached = True, is_vision = gguf_is_vision, is_lora = False, is_gguf = True, gguf_file = gguf_file, gguf_mmproj_file = mmproj_file, ) else: # Check if the HF repo contains GGUF files gguf_filename = detect_gguf_model_remote(identifier, hf_token = hf_token) if gguf_filename: # Preflight: verify llama-server binary exists BEFORE user waits # for a multi-GB download that llama-server handles natively from core.inference.llama_cpp import LlamaCppBackend if not LlamaCppBackend._find_llama_server_binary(): raise RuntimeError( "llama-server binary not found — cannot load GGUF models. " "Run setup.sh to build it, or set LLAMA_SERVER_PATH." 
) # Use list_gguf_variants() to detect vision & resolve variant variants, has_vision = list_gguf_variants(identifier, hf_token = hf_token) variant = gguf_variant if not variant: # Auto-select best quantization variant_filenames = [v.filename for v in variants] best = _pick_best_gguf(variant_filenames) if best: variant = _extract_quant_label(best) else: variant = "Q4_K_M" # Fallback — llama-server's own default display_name = f"{identifier.split('/')[-1]} ({variant})" logger.info( f"Detected remote GGUF repo '{identifier}', " f"variant={variant}, vision={has_vision}" ) return cls( identifier = identifier, display_name = display_name, path = identifier, is_local = False, is_cached = False, is_vision = has_vision, is_lora = False, is_gguf = True, gguf_file = None, gguf_hf_repo = identifier, gguf_variant = variant, ) # Auto-detect LoRA for local paths (check adapter_config.json on disk) if not is_lora and is_local: detected_base = get_base_model_from_lora(path) if detected_base: is_lora = True logger.info( f"Auto-detected local LoRA adapter at '{path}' (base: {detected_base})" ) # Auto-detect LoRA for remote HF models (check repo file listing) if not is_lora and not is_local: try: from huggingface_hub import model_info as hf_model_info info = hf_model_info(identifier, token = hf_token) repo_files = [s.rfilename for s in info.siblings] if "adapter_config.json" in repo_files: is_lora = True logger.info(f"Auto-detected remote LoRA adapter: '{identifier}'") except Exception as e: logger.debug( f"Could not check remote LoRA status for '{identifier}': {e}" ) # Handle LoRA adapters base_model = None if is_lora: if is_local: # Local LoRA: read adapter_config.json from disk base_model = get_base_model_from_lora(path) else: # Remote LoRA: download adapter_config.json from HF try: from huggingface_hub import hf_hub_download config_path = hf_hub_download( identifier, "adapter_config.json", token = hf_token ) with open(config_path, "r") as f: adapter_config = json.load(f) 
base_model = adapter_config.get("base_model_name_or_path") if base_model: logger.info(f"Resolved remote LoRA base model: '{base_model}'") except Exception as e: logger.warning( f"Could not download adapter_config.json for '{identifier}': {e}" ) if not base_model: logger.warning(f"Could not determine base model for LoRA '{path}'") return None check_model = base_model else: check_model = identifier vision = is_vision_model(check_model, hf_token = hf_token) audio_type_val = detect_audio_type(check_model, hf_token = hf_token) has_audio_in = is_audio_input_type(audio_type_val) display_name = Path(path).name if is_local else identifier.split("/")[-1] return cls( identifier = identifier, display_name = display_name, path = path, is_local = is_local, is_cached = is_model_cached(identifier) if not is_local else True, is_vision = vision, is_lora = is_lora, is_audio = audio_type_val is not None and audio_type_val != "audio_vlm", audio_type = audio_type_val, has_audio_input = has_audio_in, base_model = base_model, ) @classmethod def from_ui_selection( cls, dropdown_value: Optional[str], search_value: Optional[str], local_models: list = None, hf_token: Optional[str] = None, is_lora: bool = False, ) -> Optional["ModelConfig"]: """ Create a universal ModelConfig from UI dropdown/search selections. Handles base models and LoRA adapters. 
""" selected = None if search_value and search_value.strip(): selected = search_value.strip() elif dropdown_value: selected = dropdown_value if not selected: return None display_name = selected # Use the correct 'local_models' parameter to resolve display names if " (Active)" in selected or " (Ready)" in selected: clean_display_name = selected.replace(" (Active)", "").replace( " (Ready)", "" ) if local_models: for local_display, local_path in local_models: if local_display == clean_display_name: selected = local_path break # Clean all UI status indicators to get the final identifier identifier = selected for status in UI_STATUS_INDICATORS: identifier = identifier.replace(status, "") identifier = identifier.strip() is_local = is_local_path(identifier) path = normalize_path(identifier) if is_local else identifier # Add unsloth/ prefix for shorthand HF models if not is_local and "/" not in identifier: identifier = f"unsloth/{identifier}" path = identifier # --- Logic for Base Model and Vision Detection --- base_model = None is_vision = False if is_lora: # For a LoRA, we MUST find its base model. base_model = get_base_model_from_lora(path) if not base_model: logger.warning( f"Could not determine base model for LoRA '{path}'. Cannot create config." ) return None # Cannot proceed without a base model # A LoRA's vision capability is determined by its base model. is_vision = is_vision_model(base_model, hf_token = hf_token) else: # For a base model, just check its own vision status. 
is_vision = is_vision_model(identifier, hf_token = hf_token) from utils.paths import is_model_cached is_cached = is_model_cached(identifier) if not is_local else True return cls( identifier = identifier, display_name = display_name, path = path, is_local = is_local, is_cached = is_cached, is_vision = is_vision, is_lora = is_lora, base_model = base_model, # This will be None for base models, and populated for LoRAs ) ================================================ FILE: studio/backend/utils/paths/__init__.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Path utilities for model and dataset handling """ from .path_utils import normalize_path, is_local_path, is_model_cached, get_cache_path from .storage_roots import ( studio_root, assets_root, datasets_root, dataset_uploads_root, recipe_datasets_root, outputs_root, exports_root, auth_root, auth_db_path, tmp_root, seed_uploads_root, unstructured_seed_cache_root, oxc_validator_tmp_root, tensorboard_root, ensure_dir, ensure_studio_directories, resolve_under_root, resolve_output_dir, resolve_export_dir, resolve_tensorboard_dir, resolve_dataset_path, ) __all__ = [ "normalize_path", "is_local_path", "is_model_cached", "get_cache_path", "studio_root", "assets_root", "datasets_root", "dataset_uploads_root", "recipe_datasets_root", "outputs_root", "exports_root", "auth_root", "auth_db_path", "tmp_root", "seed_uploads_root", "unstructured_seed_cache_root", "oxc_validator_tmp_root", "tensorboard_root", "ensure_dir", "ensure_studio_directories", "resolve_under_root", "resolve_output_dir", "resolve_export_dir", "resolve_tensorboard_dir", "resolve_dataset_path", ] ================================================ FILE: studio/backend/utils/paths/path_utils.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth 
AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Path utilities for model and dataset handling """ import os from pathlib import Path from typing import Optional import structlog from loggers import get_logger logger = get_logger(__name__) def normalize_path(path: str) -> str: """ Convert Windows paths to WSL format if needed. Examples: C:\\Users\\... -> /mnt/c/Users/... /home/user/... -> /home/user/... (unchanged) """ if not path: return path # Handle Windows drive letters (C:\\ or c:\\) if len(path) >= 3 and path[1] == ":" and path[2] in ("\\", "/"): drive = path[0].lower() rest = path[3:].replace("\\", "/") return f"/mnt/{drive}/{rest}" # Already Unix-style or relative return path.replace("\\", "/") def is_local_path(path: str) -> bool: """ Check if path is a local filesystem path vs HuggingFace model identifier. Examples: True: /home/user/model, C:\\models, ./model, ~/model False: unsloth/llama-3.1-8b, microsoft/phi-2 """ if not path: return False # If it exists on disk, treat as local (covers relative paths like "outputs/foo"). 
try: if Path(normalize_path(path)).expanduser().exists(): return True except Exception: pass # Obvious HF patterns if path.count("/") == 1 and not path.startswith(("/", ".", "~")): return False # Looks like org/model format # Filesystem indicators return ( path.startswith(("/", ".", "~")) # Unix absolute/relative or ":" in path # Windows drive or URL or "\\" in path # Windows separator or os.path.isabs(path) # System-absolute ) def get_cache_path(model_name: str) -> Optional[Path]: """Get HuggingFace cache path for a model if it exists.""" cache_dir = Path.home() / ".cache" / "huggingface" / "hub" model_cache_name = model_name.replace("/", "--") model_cache_path = cache_dir / f"models--{model_cache_name}" return model_cache_path if model_cache_path.exists() else None def is_model_cached(model_name: str) -> bool: """Check if model is downloaded in HuggingFace cache.""" cache_path = get_cache_path(model_name) if not cache_path: return False # Check for actual model files for suffix in [".safetensors", ".bin", ".json"]: if list(cache_path.rglob(f"*{suffix}")): return True return False ================================================ FILE: studio/backend/utils/paths/storage_roots.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 from __future__ import annotations import os from pathlib import Path import tempfile def studio_root() -> Path: return Path.home() / ".unsloth" / "studio" def cache_root() -> Path: """Central cache directory for all studio downloads (models, datasets, etc.).""" return Path.home() / ".unsloth" / "studio" / "cache" def assets_root() -> Path: return studio_root() / "assets" def datasets_root() -> Path: return assets_root() / "datasets" def dataset_uploads_root() -> Path: return datasets_root() / "uploads" def recipe_datasets_root() -> Path: return datasets_root() / "recipes" def outputs_root() -> Path: return studio_root() / "outputs" def exports_root() -> Path: return studio_root() / "exports" def auth_root() -> Path: return studio_root() / "auth" def auth_db_path() -> Path: return auth_root() / "auth.db" def tmp_root() -> Path: return Path(tempfile.gettempdir()) / "unsloth-studio" def seed_uploads_root() -> Path: return tmp_root() / "seed-uploads" def unstructured_seed_cache_root() -> Path: return tmp_root() / "unstructured-seed-cache" def oxc_validator_tmp_root() -> Path: return tmp_root() / "oxc-validator" def tensorboard_root() -> Path: return studio_root() / "runs" def ensure_dir(path: Path) -> Path: path.mkdir(parents = True, exist_ok = True) return path def _setup_cache_env() -> None: """Set cache environment variables for HuggingFace, uv, and vLLM. Only sets variables that are not already set by the user, so explicit overrides (e.g. HF_HOME=/data/hf) are respected. Works on Linux, macOS, and Windows. 
""" root = cache_root() hf_dir = root / "huggingface" defaults = { "HF_HOME": str(hf_dir), "HF_HUB_CACHE": str(hf_dir / "hub"), "HF_XET_CACHE": str(hf_dir / "xet"), "UV_CACHE_DIR": str(root / "uv"), "VLLM_CACHE_ROOT": str(root / "vllm"), } for key, value in defaults.items(): if key not in os.environ: os.environ[key] = value Path(value).mkdir(parents = True, exist_ok = True) def ensure_studio_directories() -> None: """Create all standard studio directories on startup.""" for dir_fn in ( studio_root, assets_root, datasets_root, dataset_uploads_root, recipe_datasets_root, outputs_root, exports_root, auth_root, tensorboard_root, ): ensure_dir(dir_fn()) _setup_cache_env() def _clean_relative_path( path_value: str, *, strip_prefixes: tuple[str, ...] = () ) -> Path: path = Path(path_value).expanduser() parts = [part for part in path.parts if part not in ("", ".")] while parts and parts[0] in strip_prefixes: parts = parts[1:] return Path(*parts) if parts else Path() def resolve_under_root( path_value: str | None, *, root: Path, strip_prefixes: tuple[str, ...] 
= (), ) -> Path: if not path_value or not str(path_value).strip(): return root path = Path(str(path_value).strip()).expanduser() if path.is_absolute(): return path cleaned = _clean_relative_path(str(path), strip_prefixes = strip_prefixes) return root / cleaned def resolve_output_dir(path_value: str | None = None) -> Path: return resolve_under_root( path_value, root = outputs_root(), strip_prefixes = ("outputs",), ) def resolve_export_dir(path_value: str | None = None) -> Path: return resolve_under_root( path_value, root = exports_root(), strip_prefixes = ("exports",), ) def resolve_tensorboard_dir(path_value: str | None = None) -> Path: return resolve_under_root( path_value, root = tensorboard_root(), strip_prefixes = ("runs", "tensorboard"), ) def resolve_dataset_path(path_value: str) -> Path: path = Path(path_value).expanduser() if path.is_absolute(): return path parts = [part for part in Path(path_value).parts if part not in ("", ".")] if parts[:2] == ["assets", "datasets"]: parts = parts[2:] if parts and parts[0] == "uploads": cleaned = Path(*parts[1:]) if len(parts) > 1 else Path() return dataset_uploads_root() / cleaned if parts and parts[0] == "recipes": cleaned = Path(*parts[1:]) if len(parts) > 1 else Path() return recipe_datasets_root() / cleaned cleaned = Path(*parts) if parts else Path() candidates = [ dataset_uploads_root() / cleaned, recipe_datasets_root() / cleaned, datasets_root() / cleaned, dataset_uploads_root() / cleaned.name, recipe_datasets_root() / cleaned.name, ] for candidate in candidates: if candidate.exists(): return candidate return candidates[0] ================================================ FILE: studio/backend/utils/transformers_version.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 """ Automatic transformers version switching. 
Some newer model architectures (Ministral-3, GLM-4.7-Flash, Qwen3-30B-A3B MoE, tiny_qwen3_moe) require transformers>=5.3.0, while everything else needs the default 4.57.x that ships with Unsloth. When loading a LoRA adapter with a custom name, we resolve the base model from ``adapter_config.json`` and check *that* against the model list. Strategy: Training and inference run in subprocesses that activate the correct version via sys.path (prepending .venv_t5/ for 5.x models). See: - core/training/worker.py - core/inference/worker.py For export (still in-process), ensure_transformers_version() does a lightweight sys.path swap using the same .venv_t5/ directory pre-installed by setup.sh. """ import importlib import json import structlog from loggers import get_logger import os import shutil import subprocess import sys from pathlib import Path logger = get_logger(__name__) # --------------------------------------------------------------------------- # Detection # --------------------------------------------------------------------------- # Lowercase substrings — if ANY appears anywhere in the lowered model name, # we need transformers 5.x. TRANSFORMERS_5_MODEL_SUBSTRINGS: tuple[str, ...] = ( "ministral-3-", # Ministral-3-{3,8,14}B-{Instruct,Reasoning,Base}-2512 "glm-4.7-flash", # GLM-4.7-Flash "qwen3-30b-a3b", # Qwen3-30B-A3B-Instruct-2507 and variants "qwen3.5", # Qwen3.5 family (35B-A3B, etc.) 
"qwen3-next", # Qwen3-Next and variants "tiny_qwen3_moe", # imdatta0/tiny_qwen3_moe_2.8B_0.7B ) # Tokenizer classes that only exist in transformers>=5.x _TRANSFORMERS_5_TOKENIZER_CLASSES: set[str] = { "TokenizersBackend", } # Cache for dynamic tokenizer_config.json lookups to avoid repeated fetches _tokenizer_class_cache: dict[str, bool] = {} # Versions TRANSFORMERS_5_VERSION = "5.3.0" TRANSFORMERS_DEFAULT_VERSION = "4.57.6" # Pre-installed directory for transformers 5.x — created by setup.sh / setup.ps1 _VENV_T5_DIR = str(Path.home() / ".unsloth" / "studio" / ".venv_t5") def _resolve_base_model(model_name: str) -> str: """If *model_name* points to a LoRA adapter, return its base model. Checks for ``adapter_config.json`` locally first. Only calls the heavier ``get_base_model_from_lora`` for paths that are actual local directories (avoids noisy warnings for plain HF model IDs). Returns the original *model_name* unchanged if it is not a LoRA adapter. """ # --- Fast local check --------------------------------------------------- local_path = Path(model_name) adapter_cfg_path = local_path / "adapter_config.json" if adapter_cfg_path.is_file(): try: with open(adapter_cfg_path) as f: cfg = json.load(f) base = cfg.get("base_model_name_or_path") if base: logger.info( "Resolved LoRA adapter '%s' → base model '%s'", model_name, base, ) return base except Exception as exc: logger.debug("Could not read %s: %s", adapter_cfg_path, exc) # --- Only try the heavier fallback for local directories ---------------- if local_path.is_dir(): try: from utils.models import get_base_model_from_lora base = get_base_model_from_lora(model_name) if base: logger.info( "Resolved LoRA adapter '%s' → base model '%s' " "(via get_base_model_from_lora)", model_name, base, ) return base except Exception as exc: logger.debug( "get_base_model_from_lora failed for '%s': %s", model_name, exc, ) return model_name def _check_tokenizer_config_needs_v5(model_name: str) -> bool: """Fetch tokenizer_config.json 
from HuggingFace and check if the tokenizer_class requires transformers 5.x. Results are cached in ``_tokenizer_class_cache`` to avoid repeated fetches. Returns False on any network/parse error (fail-open to default version). """ if model_name in _tokenizer_class_cache: return _tokenizer_class_cache[model_name] import urllib.request url = f"https://huggingface.co/{model_name}/raw/main/tokenizer_config.json" try: req = urllib.request.Request(url, headers = {"User-Agent": "unsloth-studio"}) with urllib.request.urlopen(req, timeout = 10) as resp: data = json.loads(resp.read().decode()) tokenizer_class = data.get("tokenizer_class", "") result = tokenizer_class in _TRANSFORMERS_5_TOKENIZER_CLASSES if result: logger.info( "Dynamic check: %s uses tokenizer_class=%s (requires transformers 5.x)", model_name, tokenizer_class, ) _tokenizer_class_cache[model_name] = result return result except Exception as exc: logger.debug( "Could not fetch tokenizer_config.json for '%s': %s", model_name, exc ) _tokenizer_class_cache[model_name] = False return False def needs_transformers_5(model_name: str) -> bool: """Return True if *model_name* belongs to an architecture that requires ``transformers>=5.3.0``. First checks the hardcoded substring list for known models, then dynamically fetches ``tokenizer_config.json`` from HuggingFace to check if the tokenizer_class (e.g. ``TokenizersBackend``) requires v5. 
""" lowered = model_name.lower() if any(sub in lowered for sub in TRANSFORMERS_5_MODEL_SUBSTRINGS): return True return _check_tokenizer_config_needs_v5(model_name) # --------------------------------------------------------------------------- # Version switching (in-process — used only by export) # --------------------------------------------------------------------------- def _get_in_memory_version() -> str | None: """Return the transformers version currently loaded in this process.""" tf = sys.modules.get("transformers") if tf is not None: return getattr(tf, "__version__", None) return None # All top-level prefixes that hold references to transformers internals. _PURGE_PREFIXES = ( "transformers", "huggingface_hub", "unsloth", "unsloth_zoo", "peft", "trl", "accelerate", "auto_gptq", # NOTE: bitsandbytes is intentionally EXCLUDED — it registers torch custom # operators at import time via torch.library.define(). Those registrations # live in torch's global operator registry which survives module purge. # Re-importing bitsandbytes after purge → duplicate registration → crash. # Our own modules that import from transformers at module level # (e.g. model_config.py: `from transformers import AutoConfig`) "utils.models", "core.training", "core.inference", "core.export", ) def _purge_modules() -> int: """Remove all cached modules for transformers and its dependents. Returns the number of modules purged. 
""" importlib.invalidate_caches() to_remove = [ k for k in list(sys.modules.keys()) if any(k == p or k.startswith(p + ".") for p in _PURGE_PREFIXES) ] for key in to_remove: del sys.modules[key] return len(to_remove) _VENV_T5_PACKAGES = ( f"transformers=={TRANSFORMERS_5_VERSION}", "huggingface_hub==1.7.1", "hf_xet==1.4.2", "tiktoken", ) def _venv_t5_is_valid() -> bool: """Return True if .venv_t5/ has all required packages at the correct versions.""" if not os.path.isdir(_VENV_T5_DIR) or not os.listdir(_VENV_T5_DIR): return False # Check that the key package directories exist AND match the required version for pkg_spec in _VENV_T5_PACKAGES: parts = pkg_spec.split("==") pkg_name = parts[0] pkg_version = parts[1] if len(parts) > 1 else None pkg_name_norm = pkg_name.replace("-", "_") # Check directory exists if not any( (Path(_VENV_T5_DIR) / d).is_dir() for d in (pkg_name_norm, pkg_name_norm.replace("_", "-")) ): return False # For unpinned packages, existence is enough if pkg_version is None: continue # Check version via .dist-info metadata dist_info_found = False for di in Path(_VENV_T5_DIR).glob(f"{pkg_name_norm}-*.dist-info"): metadata = di / "METADATA" if not metadata.is_file(): continue for line in metadata.read_text(errors = "replace").splitlines(): if line.startswith("Version:"): installed_ver = line.split(":", 1)[1].strip() if installed_ver != pkg_version: logger.info( ".venv_t5 has %s==%s but need %s", pkg_name, installed_ver, pkg_version, ) return False dist_info_found = True break if dist_info_found: break if not dist_info_found: return False return True def _install_to_venv_t5(pkg: str) -> bool: """Install a single package into .venv_t5/, preferring uv then pip.""" # Try uv first (faster) if already on PATH -- do NOT install uv at runtime if shutil.which("uv"): result = subprocess.run( [ "uv", "pip", "install", "--python", sys.executable, "--target", _VENV_T5_DIR, "--no-deps", "--upgrade", pkg, ], stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = 
True, ) if result.returncode == 0: return True logger.warning("uv install of %s failed, falling back to pip", pkg) # Fallback to pip result = subprocess.run( [ sys.executable, "-m", "pip", "install", "--target", _VENV_T5_DIR, "--no-deps", "--upgrade", pkg, ], stdout = subprocess.PIPE, stderr = subprocess.STDOUT, text = True, ) if result.returncode != 0: logger.error("install failed:\n%s", result.stdout) return False return True def _ensure_venv_t5_exists() -> bool: """Ensure .venv_t5/ exists with all required packages. Install if missing.""" if _venv_t5_is_valid(): return True logger.warning( ".venv_t5 not found or incomplete at %s -- installing at runtime", _VENV_T5_DIR ) shutil.rmtree(_VENV_T5_DIR, ignore_errors = True) os.makedirs(_VENV_T5_DIR, exist_ok = True) for pkg in _VENV_T5_PACKAGES: if not _install_to_venv_t5(pkg): return False logger.info("Installed transformers 5.x to %s", _VENV_T5_DIR) return True def _activate_5x() -> None: """Prepend .venv_t5/ to sys.path, purge stale modules, reimport.""" if not _ensure_venv_t5_exists(): raise RuntimeError( f"Cannot activate transformers 5.x: .venv_t5 missing at {_VENV_T5_DIR}" ) if _VENV_T5_DIR not in sys.path: sys.path.insert(0, _VENV_T5_DIR) logger.info("Prepended %s to sys.path", _VENV_T5_DIR) count = _purge_modules() logger.info("Purged %d cached modules", count) import transformers logger.info("Loaded transformers %s", transformers.__version__) def _deactivate_5x() -> None: """Remove .venv_t5/ from sys.path, purge stale modules, reimport.""" while _VENV_T5_DIR in sys.path: sys.path.remove(_VENV_T5_DIR) logger.info("Removed %s from sys.path", _VENV_T5_DIR) count = _purge_modules() logger.info("Purged %d cached modules", count) import transformers logger.info("Reverted to transformers %s", transformers.__version__) def ensure_transformers_version(model_name: str) -> None: """Ensure the correct ``transformers`` version is active for *model_name*. 
Uses sys.path with .venv_t5/ (pre-installed by setup.sh): • Need 5.x → prepend .venv_t5/ to sys.path, purge modules. • Need 4.x → remove .venv_t5/ from sys.path, purge modules. For LoRA adapters with custom names, the base model is resolved from ``adapter_config.json`` before checking. NOTE: Training and inference use subprocess isolation instead of this function. This is only used by the export path (routes/export.py). """ # Resolve LoRA adapters to their base model for accurate detection resolved = _resolve_base_model(model_name) want_5 = needs_transformers_5(resolved) target_version = TRANSFORMERS_5_VERSION if want_5 else TRANSFORMERS_DEFAULT_VERSION target_major = int(target_version.split(".")[0]) # Check what's actually loaded in memory in_memory = _get_in_memory_version() logger.info( "Version check for '%s' (resolved: '%s'): need=%s, in_memory=%s", model_name, resolved, target_version, in_memory, ) # --- Already correct? --------------------------------------------------- if in_memory is not None: in_memory_major = int(in_memory.split(".")[0]) if in_memory_major == target_major: logger.info( "transformers %s already loaded — correct for '%s'", in_memory, model_name, ) return # --- Switch version ----------------------------------------------------- if want_5: logger.info("Activating transformers %s via .venv_t5…", TRANSFORMERS_5_VERSION) _activate_5x() else: logger.info( "Reverting to default transformers %s…", TRANSFORMERS_DEFAULT_VERSION ) _deactivate_5x() final = _get_in_memory_version() logger.info("✓ transformers version is now %s", final) ================================================ FILE: studio/backend/utils/utils.py ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 """ Shared backend utilities """ import os import structlog from loggers import get_logger from contextlib import contextmanager from pathlib import Path import shutil import tempfile logger = get_logger(__name__) @contextmanager def without_hf_auth(): """ Context manager to temporarily disable HuggingFace authentication. Usage: with without_hf_auth(): # Code that should run without cached tokens model_info(model_name, token=None) """ # Save environment variables saved_env = {} env_vars = ["HF_TOKEN", "HUGGINGFACE_HUB_TOKEN", "HF_HOME"] for var in env_vars: if var in os.environ: saved_env[var] = os.environ[var] del os.environ[var] # Save disable flag saved_disable = os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN") os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1" # Move token files temporarily token_files = [] token_locations = [ Path.home() / ".cache" / "huggingface" / "token", Path.home() / ".huggingface" / "token", ] for token_loc in token_locations: if token_loc.exists(): temp = tempfile.NamedTemporaryFile(delete = False) temp.close() shutil.move(str(token_loc), temp.name) token_files.append((token_loc, temp.name)) try: yield finally: # Restore tokens for original, temp in token_files: try: original.parent.mkdir(parents = True, exist_ok = True) shutil.move(temp, str(original)) except Exception as e: logger.error(f"Failed to restore token {original}: {e}") # Restore environment for var, value in saved_env.items(): os.environ[var] = value if saved_disable is not None: os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = saved_disable else: os.environ.pop("HF_HUB_DISABLE_IMPLICIT_TOKEN", None) def format_error_message(error: Exception, model_name: str) -> str: """ Format user-friendly error messages for common issues. 
Args: error: The exception that occurred model_name: Name of the model being loaded Returns: User-friendly error string """ error_str = str(error).lower() model_short = model_name.split("/")[-1] if "/" in model_name else model_name if "repository not found" in error_str or "404" in error_str: return f"Model '{model_short}' not found. Check the model name." if "401" in error_str or "unauthorized" in error_str: return f"Authentication failed for '{model_short}'. Please provide a valid HF token." if "gated" in error_str or "access to model" in error_str: return f"Model '{model_short}' requires authentication. Please provide a valid HF token." if "invalid user token" in error_str: return "Invalid HF token. Please check your token and try again." if ( "memory" in error_str or "cuda" in error_str or "mlx" in error_str or "out of memory" in error_str ): from utils.hardware import get_device device = get_device() device_label = {"cuda": "GPU", "mlx": "Apple Silicon GPU", "cpu": "system"}.get( device.value, "GPU" ) return f"Not enough {device_label} memory to load '{model_short}'. Try a smaller model or free memory." # Generic fallback return str(error) ================================================ FILE: studio/frontend/.gitignore ================================================ # SPDX-License-Identifier: AGPL-3.0-only # Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 # Logs logs *.log npm-debug.log* yarn-debug.log* yarn-error.log* pnpm-debug.log* lerna-debug.log* node_modules dist dist-ssr test/ *.local .env .env.* .omx/ # Editor directories and files .vscode/* !.vscode/extensions.json .idea .DS_Store ._* *.suo *.ntvs* *.njsproj *.sln *.sw? 
/src/features/recipe-studio/AGENTS.md /docs ================================================ FILE: studio/frontend/.gitkeep ================================================ ================================================ FILE: studio/frontend/biome.json ================================================ { "$schema": "https://biomejs.dev/schemas/1.9.4/schema.json", "files": { "ignore": [ "dist", "node_modules", "test", "test/**", "**/._*", "._*", "**/.DS_Store", "tsconfig*.json" ] }, "formatter": { "enabled": true, "indentStyle": "space", "indentWidth": 2 }, "organizeImports": { "enabled": true }, "linter": { "enabled": true, "rules": { "recommended": true, "a11y": { "all": true }, "complexity": { "all": true }, "correctness": { "all": true, "useImportExtensions": "off" }, "performance": { "all": true }, "security": { "all": true }, "style": { "all": true, "useNamingConvention": { "options": { "strictCase": false } } }, "suspicious": { "all": true, "noReactSpecificProps": "off" } } }, "overrides": [ { "include": ["vite.config.ts", "eslint.config.js"], "linter": { "rules": { "correctness": { "noNodejsModules": "off" }, "style": { "noDefaultExport": "off" } } } }, { "include": ["src/components/assistant-ui/reasoning.tsx"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/components/assistant-ui/attachment.tsx"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/components/assistant-ui/tool-fallback.tsx"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/components/component-example.tsx"], "linter": { "rules": { "style": { "noNamespaceImport": "off" } } } }, { "include": ["src/config/env.ts"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/components/layout/index.ts"], "linter": { "rules": { "performance": { "noBarrelFile": "off" } } } }, { "include": ["src/features/**/index.ts"], "linter": { "rules": { 
"performance": { "noBarrelFile": "off" } } } }, { "include": ["src/features/chat/thread-sidebar.tsx"], "linter": { "rules": { "a11y": { "useSemanticElements": "off" } } } }, { "include": ["src/features/chat/runtime-provider.tsx"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/components/assistant-ui/thread.tsx"], "linter": { "rules": { "style": { "useNamingConvention": "off" } } } }, { "include": ["src/features/onboarding/components/steps/summary-step.tsx"], "linter": { "rules": { "style": { "useExplicitLengthCheck": "off" } } } }, { "include": ["src/components/ui/**"], "linter": { "enabled": false }, "formatter": { "enabled": false }, "organizeImports": { "enabled": false } } ] } ================================================ FILE: studio/frontend/components.json ================================================ { "$schema": "https://ui.shadcn.com/schema.json", "style": "radix-maia", "rsc": false, "tsx": true, "tailwind": { "config": "", "css": "src/index.css", "baseColor": "neutral", "cssVariables": true, "prefix": "" }, "iconLibrary": "hugeicons", "menuColor": "default", "menuAccent": "subtle", "aliases": { "components": "@/components", "utils": "@/lib/utils", "ui": "@/components/ui", "lib": "@/lib", "hooks": "@/hooks" }, "registries": { "@magicui": "https://magicui.design/r/{name}" } } ================================================ FILE: studio/frontend/data-designer.openapi (1).yaml ================================================ openapi: 3.1.0 info: title: NeMo Data Designer Microservice description: Service for generating synthetic data. 
version: 1.5.0 paths: /v1/data-designer/jobs: post: tags: - Data Designer summary: Create Job operationId: create_job_v1_data_designer_jobs_post requestBody: required: true content: application/json: schema: $ref: '#/components/schemas/DataDesignerJobRequest' responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/DataDesignerJob' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' get: tags: - Data Designer summary: List Jobs operationId: list_jobs_v1_data_designer_jobs_get parameters: - name: page in: query required: false schema: type: integer exclusiveMinimum: 0 description: Page number. default: 1 title: Page description: Page number. - name: page_size in: query required: false schema: type: integer exclusiveMinimum: 0 description: Page size. default: 10 title: Page Size description: Page size. - name: sort in: query required: false schema: allOf: - $ref: '#/components/schemas/DataDesignerJobsSortField' description: The field to sort by. To sort in decreasing order, use `-` in front of the field name. default: -created_at description: The field to sort by. To sort in decreasing order, use `-` in front of the field name. - in: query name: filter style: deepObject required: false explode: true schema: $ref: '#/components/schemas/DataDesignerJobsListFilter' description: Filter jobs on various criteria. 
- in: query name: search style: deepObject required: false explode: true schema: $ref: '#/components/schemas/DataDesignerJobsSearch' description: "\nSearch jobs using substring matching.\nYou can combine multiple\ \ search fields and filters.\n\nFor example:\n- `?search[name]=training`:\ \ searches all jobs with 'training' in the name.\n- `?search[project]=my-project`:\ \ searches all jobs with 'my-project'\n in the project field.\n- `?search[name]=training&search[name]=eval`:\ \ searches all jobs with\n 'training' OR 'eval' in the name.\n- `?search[name]=training&search[project]=my-project`:\ \ searches all\n jobs with 'training' in the name AND 'my-project' in the\ \ project.\n" responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/DataDesignerJobsPage' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}: get: tags: - Data Designer summary: Get Job operationId: get_job_v1_data_designer_jobs__job_id__get parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/DataDesignerJob' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' delete: tags: - Data Designer summary: Delete Job operationId: delete_job_v1_data_designer_jobs__job_id__delete parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: {} '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/cancel: post: tags: - Data Designer summary: Cancel Job operationId: cancel_job_v1_data_designer_jobs__job_id__cancel_post parameters: - 
name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/DataDesignerJob' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/logs: get: tags: - Data Designer summary: Get Job Logs operationId: get_job_logs_v1_data_designer_jobs__job_id__logs_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id - name: limit in: query required: false schema: anyOf: - type: integer - type: 'null' title: Limit - name: page_cursor in: query required: false schema: anyOf: - type: string - type: 'null' title: Page Cursor responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/PlatformJobLogPage' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/results: get: tags: - Data Designer summary: List Job Results operationId: list_job_results_v1_data_designer_jobs__job_id__results_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/PlatformJobListResultResponse' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/results/analysis/download: get: tags: - Data Designer summary: Download Job Result Analysis operationId: download_job_result_analysis_v1_data_designer_jobs__job_id__results_analysis_download_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: {} '404': description: Not Found '422': description: 
Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/results/dataset/download: get: tags: - Data Designer summary: Download Job Result Dataset operationId: download_job_result_dataset_v1_data_designer_jobs__job_id__results_dataset_download_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/octet-stream: schema: type: string format: binary '404': description: Not Found '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/results/{result_name}: get: tags: - Data Designer summary: Get Job Result operationId: get_job_result_v1_data_designer_jobs__job_id__results__result_name__get parameters: - name: job_id in: path required: true schema: type: string title: Job Id - name: result_name in: path required: true schema: type: string title: Result Name responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/PlatformJobResultResponse' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/jobs/{job_id}/results/{result_name}/download: get: tags: - Data Designer summary: Download Job Result operationId: download_job_result_v1_data_designer_jobs__job_id__results__result_name__download_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id - name: result_name in: path required: true schema: type: string title: Result Name responses: '200': description: Successful Response content: application/octet-stream: schema: type: string format: binary '404': description: Not Found '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' 
/v1/data-designer/jobs/{job_id}/status: get: tags: - Data Designer summary: Get Job Status operationId: get_job_status_v1_data_designer_jobs__job_id__status_get parameters: - name: job_id in: path required: true schema: type: string title: Job Id responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/PlatformJobStatusResponse' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/preview: post: tags: - Data Designer summary: Generate preview Data Designer operationId: preview_v1_data_designer_preview_post requestBody: content: application/json: schema: $ref: '#/components/schemas/PreviewRequest' required: true responses: '200': description: Successful Response content: application/jsonl: schema: $ref: '#/components/schemas/PreviewMessage' '422': description: Validation Error content: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' /v1/data-designer/settings: get: tags: - Data Designer summary: Get Data Designer settings description: Returns the settings available for Data Designer. operationId: get_settings_v1_data_designer_settings_get responses: '200': description: Successful Response content: application/json: schema: $ref: '#/components/schemas/SettingsResponse' components: schemas: BernoulliMixtureSamplerParams: properties: p: type: number maximum: 1.0 minimum: 0.0 title: P description: Bernoulli distribution probability of success. dist_name: type: string title: Dist Name description: Mixture distribution name. Samples will be equal to the distribution sample with probability `p`, otherwise equal to 0. Must be a valid scipy.stats distribution name. dist_params: additionalProperties: true type: object title: Dist Params description: Parameters of the scipy.stats distribution given in `dist_name`. 
sampler_type: type: string const: bernoulli_mixture title: Sampler Type default: bernoulli_mixture additionalProperties: false type: object required: - p - dist_name - dist_params title: BernoulliMixtureSamplerParams description: "Parameters for sampling from a Bernoulli mixture distribution.\n\ \nCombines a Bernoulli distribution with another continuous distribution,\ \ creating a mixture\nwhere values are either 0 (with probability 1-p) or\ \ sampled from the specified distribution\n(with probability p). This is useful\ \ for modeling scenarios with many zero values mixed with\na continuous distribution\ \ of non-zero values.\n\nCommon use cases include modeling sparse events,\ \ zero-inflated data, or situations where\nan outcome either doesn't occur\ \ (0) or follows a specific distribution when it does occur.\n\nAttributes:\n\ \ p: Probability of sampling from the mixture distribution (non-zero outcome).\n\ \ Must be between 0.0 and 1.0 (inclusive). With probability 1-p, the\ \ sample is 0.\n dist_name: Name of the scipy.stats distribution to sample\ \ from when outcome is non-zero.\n Must be a valid scipy.stats distribution\ \ name (e.g., \"norm\", \"gamma\", \"expon\").\n dist_params: Parameters\ \ for the specified scipy.stats distribution." BernoulliSamplerParams: properties: p: type: number maximum: 1.0 minimum: 0.0 title: P description: Probability of success. sampler_type: type: string const: bernoulli title: Sampler Type default: bernoulli additionalProperties: false type: object required: - p title: BernoulliSamplerParams description: "Parameters for sampling from a Bernoulli distribution.\n\nSamples\ \ binary values (0 or 1) representing the outcome of a single trial with a\ \ fixed\nprobability of success. This is the simplest discrete probability\ \ distribution, useful for\nmodeling binary outcomes like success/failure,\ \ yes/no, or true/false.\n\nAttributes:\n p: Probability of success (sampling\ \ 1). 
Must be between 0.0 and 1.0 (inclusive).\n The probability of\ \ failure (sampling 0) is automatically 1 - p." BinomialSamplerParams: properties: n: type: integer title: N description: Number of trials. p: type: number maximum: 1.0 minimum: 0.0 title: P description: Probability of success on each trial. sampler_type: type: string const: binomial title: Sampler Type default: binomial additionalProperties: false type: object required: - n - p title: BinomialSamplerParams description: "Parameters for sampling from a Binomial distribution.\n\nSamples\ \ integer values representing the number of successes in a fixed number of\ \ independent\nBernoulli trials, each with the same probability of success.\ \ Commonly used to model the number\nof successful outcomes in repeated experiments.\n\ \nAttributes:\n n: Number of independent trials. Must be a positive integer.\n\ \ p: Probability of success on each trial. Must be between 0.0 and 1.0\ \ (inclusive)." BuildStage: type: string enum: - pre_batch - post_batch - pre_generation - post_generation title: BuildStage CategorySamplerParams: properties: values: items: anyOf: - type: string - type: integer - type: number type: array minItems: 1 title: Values description: List of possible categorical values that can be sampled from. weights: type: array items: type: number title: Weights description: List of unnormalized probability weights to assigned to each value, in order. Larger values will be sampled with higher probability. sampler_type: type: string const: category title: Sampler Type default: category additionalProperties: false type: object required: - values title: CategorySamplerParams description: "Parameters for categorical sampling with optional probability\ \ weighting.\n\nSamples values from a discrete set of categories. 
When weights\ \ are provided, values are\nsampled according to their assigned probabilities.\ \ Without weights, uniform sampling is used.\n\nAttributes:\n values: List\ \ of possible categorical values to sample from. Can contain strings, integers,\n\ \ or floats. Must contain at least one value.\n weights: Optional\ \ unnormalized probability weights for each value. If provided, must be\n\ \ the same length as `values`. Weights are automatically normalized\ \ to sum to 1.0.\n Larger weights result in higher sampling probability\ \ for the corresponding value." CodeLang: type: string enum: - go - javascript - java - kotlin - python - ruby - rust - scala - swift - typescript - sql:sqlite - sql:tsql - sql:bigquery - sql:mysql - sql:postgres - sql:ansi title: CodeLang CodeValidatorParams: properties: code_lang: allOf: - $ref: '#/components/schemas/CodeLang' description: The language of the code to validate additionalProperties: false type: object required: - code_lang title: CodeValidatorParams description: "Configuration for code validation. Supports Python and SQL code\ \ validation.\n\nAttributes:\n code_lang: The language of the code to validate.\ \ Supported values include: `python`,\n `sql:sqlite`, `sql:postgres`,\ \ `sql:mysql`, `sql:tsql`, `sql:bigquery`, `sql:ansi`." 
ColumnInequalityConstraint: properties: target_column: type: string title: Target Column rhs: type: string title: Rhs operator: $ref: '#/components/schemas/InequalityOperator' additionalProperties: false type: object required: - target_column - rhs - operator title: ColumnInequalityConstraint DataDesignerConfig: properties: columns: items: oneOf: - $ref: '#/components/schemas/ExpressionColumnConfig' - $ref: '#/components/schemas/LLMCodeColumnConfig' - $ref: '#/components/schemas/LLMJudgeColumnConfig' - $ref: '#/components/schemas/LLMStructuredColumnConfig' - $ref: '#/components/schemas/LLMTextColumnConfig' - $ref: '#/components/schemas/SamplerColumnConfig' - $ref: '#/components/schemas/SeedDatasetColumnConfig' - $ref: '#/components/schemas/ValidationColumnConfig' discriminator: propertyName: column_type mapping: expression: '#/components/schemas/ExpressionColumnConfig' llm-code: '#/components/schemas/LLMCodeColumnConfig-Input' llm-judge: '#/components/schemas/LLMJudgeColumnConfig-Input' llm-structured: '#/components/schemas/LLMStructuredColumnConfig-Input' llm-text: '#/components/schemas/LLMTextColumnConfig-Input' sampler: '#/components/schemas/SamplerColumnConfig' seed-dataset: '#/components/schemas/SeedDatasetColumnConfig' validation: '#/components/schemas/ValidationColumnConfig-Input' type: array minItems: 1 title: Columns model_configs: type: array items: $ref: '#/components/schemas/ModelConfigInput' title: Model Configs seed_config: $ref: '#/components/schemas/SeedConfig' constraints: type: array items: anyOf: - $ref: '#/components/schemas/ScalarInequalityConstraint' - $ref: '#/components/schemas/ColumnInequalityConstraint' title: Constraints profilers: type: array items: $ref: '#/components/schemas/JudgeScoreProfilerConfig' title: Profilers processors: type: array items: $ref: '#/components/schemas/ProcessorConfig' title: Processors additionalProperties: false type: object required: - columns title: DataDesignerConfig description: "Configuration for NeMo Data 
Designer.\n\nThis class defines the\ \ main configuration structure for NeMo Data Designer,\nwhich orchestrates\ \ the generation of synthetic data.\n\nAttributes:\n columns: Required\ \ list of column configurations defining how each column\n should be\ \ generated. Must contain at least one column.\n model_configs: Optional\ \ list of model configurations for LLM-based generation.\n Each model\ \ config defines the model, provider, and inference parameters.\n seed_config:\ \ Optional seed dataset settings to use for generation.\n constraints:\ \ Optional list of column constraints.\n profilers: Optional list of column\ \ profilers for analyzing generated data characteristics." DataDesignerJob: properties: id: type: string title: Id name: type: string title: Name description: type: string title: Description project: type: string title: Project namespace: type: string title: Namespace created_at: type: string title: Created At updated_at: type: string title: Updated At spec: $ref: '#/components/schemas/DataDesignerJobConfig' status: $ref: '#/components/schemas/PlatformJobStatus' status_details: type: object additionalProperties: true title: Status Details error_details: type: object additionalProperties: true title: Error Details ownership: type: object additionalProperties: true title: Ownership custom_fields: type: object additionalProperties: true title: Custom Fields type: object required: - name - spec title: DataDesignerJob DataDesignerJobConfig: properties: num_records: type: integer title: Num Records config: $ref: '#/components/schemas/DataDesignerConfig' type: object required: - num_records - config title: DataDesignerJobConfig DataDesignerJobRequest: properties: name: type: string title: Name description: type: string title: Description namespace: type: string title: Namespace project: type: string title: Project spec: $ref: '#/components/schemas/DataDesignerJobConfig' ownership: type: object additionalProperties: true title: Ownership custom_fields: 
type: object additionalProperties: true title: Custom Fields type: object required: - spec title: DataDesignerJobRequest DataDesignerJobsListFilter: properties: created_at: allOf: - $ref: '#/components/schemas/DatetimeFilter' description: Jobs created at 'gte' datetime or 'lte' datetime. name: type: string title: Name description: Name of the job. namespace: type: string title: Namespace description: Namespace of the job. project: type: string title: Project description: Project containing the job. status: allOf: - $ref: '#/components/schemas/PlatformJobStatus' description: The current status. updated_at: allOf: - $ref: '#/components/schemas/DatetimeFilter' description: Jobs updated at 'gte' datetime or 'lte' datetime. additionalProperties: false type: object title: DataDesignerJobsListFilter DataDesignerJobsPage: properties: object: type: string title: Object description: The type of object being returned. default: list data: items: $ref: '#/components/schemas/DataDesignerJob' type: array title: Data pagination: allOf: - $ref: '#/components/schemas/PaginationData' description: Pagination information. sort: type: string title: Sort description: The field on which the results are sorted. filter: allOf: - $ref: '#/components/schemas/DataDesignerJobsListFilter' description: Filtering information. search: allOf: - $ref: '#/components/schemas/DataDesignerJobsSearch' description: Search information. type: object required: - data title: DataDesignerJobsPage DataDesignerJobsSearch: properties: name: type: array items: type: string title: Name description: Search jobs where name contains any of these strings. project: type: array items: type: string title: Project description: Search jobs where project contains any of these strings. 
type: object title: DataDesignerJobsSearch DataDesignerJobsSortField: type: string enum: - created_at - -created_at - updated_at - -updated_at title: DataDesignerJobsSortField DatetimeFilter: properties: gte: type: string title: Gte description: Filter for results greater than or equal to this datetime. lte: type: string title: Lte description: Filter for results less than or equal to this datetime. additionalProperties: false type: object title: DatetimeFilter DatetimeSamplerParams: properties: start: type: string title: Start description: Earliest possible datetime for sampling range, inclusive. end: type: string title: End description: Latest possible datetime for sampling range, inclusive. unit: type: string enum: - Y - M - D - h - m - s title: Unit description: Sampling units, i.e. the smallest possible time interval between samples. default: D sampler_type: type: string const: datetime title: Sampler Type default: datetime additionalProperties: false type: object required: - start - end title: DatetimeSamplerParams description: "Parameters for uniform datetime sampling within a specified range.\n\ \nSamples datetime values uniformly between a start and end date with a specified\ \ granularity.\nThe sampling unit determines the smallest possible time interval\ \ between consecutive samples.\n\nAttributes:\n start: Earliest possible\ \ datetime for the sampling range (inclusive). Must be a valid\n datetime\ \ string parseable by pandas.to_datetime().\n end: Latest possible datetime\ \ for the sampling range (inclusive). 
Must be a valid\n datetime string\ \ parseable by pandas.to_datetime().\n unit: Time unit for sampling granularity.\ \ Options:\n - \"Y\": Years\n - \"M\": Months\n - \"\ D\": Days (default)\n - \"h\": Hours\n - \"m\": Minutes\n \ \ - \"s\": Seconds" DisplayModelProvider: properties: name: type: string title: Name provider_type: type: string title: Provider Type default: openai extra_body: type: object additionalProperties: true title: Extra Body allowed_models: type: array items: type: string title: Allowed Models additionalProperties: false type: object required: - name title: DisplayModelProvider DistributionType: type: string enum: - uniform - manual title: DistributionType ExpressionColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: expression title: Column Type default: expression expr: type: string title: Expr dtype: type: string enum: - int - float - str - bool title: Dtype default: str additionalProperties: false type: object required: - name - expr title: ExpressionColumnConfig description: "Configuration for derived columns using Jinja2 expressions.\n\n\ Expression columns compute values by evaluating Jinja2 templates that reference\ \ other\ncolumns. Useful for transformations, concatenations, conditional\ \ logic, and derived\nfeatures without requiring LLM generation. The expression\ \ is evaluated row-by-row.\n\nAttributes:\n expr: Jinja2 expression to\ \ evaluate. Can reference other column values using\n {{ column_name\ \ }} syntax. Supports filters, conditionals, and arithmetic.\n Must\ \ be a valid, non-empty Jinja2 template.\n dtype: Data type to cast the\ \ result to. Must be one of \"int\", \"float\", \"str\", or \"bool\".\n \ \ Defaults to \"str\". Type conversion is applied after expression evaluation.\n\ \ column_type: Discriminator field, always \"expression\" for this configuration\ \ type." 
FileStorageType: type: string enum: - nds title: FileStorageType GaussianSamplerParams: properties: mean: type: number title: Mean description: Mean of the Gaussian distribution stddev: type: number title: Stddev description: Standard deviation of the Gaussian distribution decimal_places: type: integer title: Decimal Places description: Number of decimal places to round the sampled values to. sampler_type: type: string const: gaussian title: Sampler Type default: gaussian additionalProperties: false type: object required: - mean - stddev title: GaussianSamplerParams description: "Parameters for sampling from a Gaussian (Normal) distribution.\n\ \nSamples continuous values from a normal distribution characterized by its\ \ mean and standard\ndeviation. The Gaussian distribution is one of the most\ \ commonly used probability distributions,\nappearing naturally in many real-world\ \ phenomena due to the Central Limit Theorem.\n\nAttributes:\n mean: Mean\ \ (center) of the Gaussian distribution. This is the expected value and the\n\ \ location of the distribution's peak.\n stddev: Standard deviation\ \ of the Gaussian distribution. Controls the spread or width\n of the\ \ distribution. Must be positive.\n decimal_places: Optional number of\ \ decimal places to round sampled values to. If None,\n values are\ \ not rounded." 
HTTPValidationError: properties: detail: items: $ref: '#/components/schemas/ValidationError' type: array title: Detail type: object title: HTTPValidationError ImageContext: properties: modality: allOf: - $ref: '#/components/schemas/Modality' default: image column_name: type: string title: Column Name data_type: $ref: '#/components/schemas/ModalityDataType' image_format: $ref: '#/components/schemas/ImageFormat' type: object required: - column_name - data_type title: ImageContext ImageFormat: type: string enum: - png - jpg - jpeg - gif - webp title: ImageFormat IndexRange: properties: start: type: integer minimum: 0.0 title: Start description: The start index of the index range (inclusive) end: type: integer minimum: 0.0 title: End description: The end index of the index range (inclusive) additionalProperties: false type: object required: - start - end title: IndexRange InequalityOperator: type: string enum: - lt - le - gt - ge title: InequalityOperator InferenceParametersInput: properties: temperature: anyOf: - type: number - $ref: '#/components/schemas/UniformDistribution' - $ref: '#/components/schemas/ManualDistribution' - type: 'null' title: Temperature top_p: anyOf: - type: number - $ref: '#/components/schemas/UniformDistribution' - $ref: '#/components/schemas/ManualDistribution' - type: 'null' title: Top P max_tokens: type: integer title: Max Tokens max_parallel_requests: type: integer minimum: 1.0 title: Max Parallel Requests default: 4 timeout: type: integer title: Timeout extra_body: type: object additionalProperties: true title: Extra Body additionalProperties: false type: object title: InferenceParametersInput InferenceParametersOutput: properties: temperature: anyOf: - type: number - $ref: '#/components/schemas/UniformDistribution' - $ref: '#/components/schemas/ManualDistribution' - type: 'null' title: Temperature top_p: anyOf: - type: number - $ref: '#/components/schemas/UniformDistribution' - $ref: '#/components/schemas/ManualDistribution' - type: 
'null' title: Top P max_tokens: type: integer title: Max Tokens max_parallel_requests: type: integer minimum: 1.0 title: Max Parallel Requests default: 4 timeout: type: integer title: Timeout extra_body: type: object additionalProperties: true title: Extra Body additionalProperties: false type: object title: InferenceParametersOutput JudgeScoreProfilerConfig: properties: model_alias: type: string title: Model Alias summary_score_sample_size: type: integer title: Summary Score Sample Size default: 20 additionalProperties: false type: object required: - model_alias title: JudgeScoreProfilerConfig LLMCodeColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: llm-code title: Column Type default: llm-code prompt: type: string title: Prompt model_alias: type: string title: Model Alias system_prompt: type: string title: System Prompt multi_modal_context: type: array items: $ref: '#/components/schemas/ImageContext' title: Multi Modal Context code_lang: $ref: '#/components/schemas/CodeLang' additionalProperties: false type: object required: - name - prompt - model_alias - code_lang title: LLMCodeColumnConfig description: "Configuration for code generation columns using Large Language\ \ Models.\n\nExtends LLMTextColumnConfig to generate code snippets in specific\ \ programming languages\nor SQL dialects. The generated code is automatically\ \ extracted from markdown code blocks\nfor the specified language. Inherits\ \ all prompt templating capabilities.\n\nAttributes:\n code_lang: Programming\ \ language or SQL dialect for code generation. Supported\n values include:\ \ \"python\", \"javascript\", \"typescript\", \"java\", \"kotlin\", \"go\"\ ,\n \"rust\", \"ruby\", \"scala\", \"swift\", \"sql:sqlite\", \"sql:postgres\"\ , \"sql:mysql\",\n \"sql:tsql\", \"sql:bigquery\", \"sql:ansi\". 
See\ \ CodeLang enum for complete list.\n column_type: Discriminator field,\ \ always \"llm-code\" for this configuration type." LLMJudgeColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: llm-judge title: Column Type default: llm-judge prompt: type: string title: Prompt model_alias: type: string title: Model Alias system_prompt: type: string title: System Prompt multi_modal_context: type: array items: $ref: '#/components/schemas/ImageContext' title: Multi Modal Context scores: items: $ref: '#/components/schemas/Score' type: array minItems: 1 title: Scores additionalProperties: false type: object required: - name - prompt - model_alias - scores title: LLMJudgeColumnConfig description: "Configuration for LLM-as-a-judge quality assessment and scoring\ \ columns.\n\nExtends LLMTextColumnConfig to create judge columns that evaluate\ \ and score other\ngenerated content based on the defined criteria. Useful\ \ for quality assessment, preference\nranking, and multi-dimensional evaluation\ \ of generated data.\n\nAttributes:\n scores: List of Score objects defining\ \ the evaluation dimensions. Each score\n represents a different aspect\ \ to evaluate (e.g., accuracy, relevance, fluency).\n Must contain\ \ at least one score.\n column_type: Discriminator field, always \"llm-judge\"\ \ for this configuration type." 
LLMStructuredColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: llm-structured title: Column Type default: llm-structured prompt: type: string title: Prompt model_alias: type: string title: Model Alias system_prompt: type: string title: System Prompt multi_modal_context: type: array items: $ref: '#/components/schemas/ImageContext' title: Multi Modal Context output_format: anyOf: - additionalProperties: true type: object - {} title: Output Format additionalProperties: false type: object required: - name - prompt - model_alias - output_format title: LLMStructuredColumnConfig description: "Configuration for structured JSON generation columns using Large\ \ Language Models.\n\nExtends LLMTextColumnConfig to generate structured data\ \ conforming to a specified schema.\nUses JSON schema or Pydantic models to\ \ define the expected output structure, enabling\ntype-safe and validated\ \ structured output generation. Inherits prompt templating capabilities.\n\ \nAttributes:\n output_format: The schema defining the expected output\ \ structure. Can be either:\n - A Pydantic BaseModel class (recommended)\n\ \ - A JSON schema dictionary\n column_type: Discriminator field,\ \ always \"llm-structured\" for this configuration type." 
LLMTextColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: llm-text title: Column Type default: llm-text prompt: type: string title: Prompt model_alias: type: string title: Model Alias system_prompt: type: string title: System Prompt multi_modal_context: type: array items: $ref: '#/components/schemas/ImageContext' title: Multi Modal Context additionalProperties: false type: object required: - name - prompt - model_alias title: LLMTextColumnConfig description: "Configuration for text generation columns using Large Language\ \ Models.\n\nLLM text columns generate free-form text content using language\ \ models via LiteLLM.\nPrompts support Jinja2 templating to reference values\ \ from other columns, enabling\ncontext-aware generation. The generated text\ \ can optionally include reasoning traces\nwhen models support extended thinking.\n\ \nAttributes:\n prompt: Prompt template for text generation. Supports Jinja2\ \ syntax to\n reference other columns (e.g., \"Write a story about\ \ {{ character_name }}\").\n Must be a valid Jinja2 template.\n \ \ model_alias: Alias of the model configuration to use for generation.\n \ \ Must match a model alias defined when initializing the DataDesignerConfigBuilder.\n\ \ system_prompt: Optional system prompt to set model behavior and constraints.\n\ \ Also supports Jinja2 templating. If provided, must be a valid Jinja2\ \ template.\n Do not put any output parsing instructions in the system\ \ prompt. Instead,\n use the appropriate column type for the output\ \ you want to generate - e.g.,\n `LLMStructuredColumnConfig` for structured\ \ output, `LLMCodeColumnConfig` for code.\n multi_modal_context: Optional\ \ list of image contexts for multi-modal generation.\n Enables vision-capable\ \ models to generate text based on image inputs.\n column_type: Discriminator\ \ field, always \"llm-text\" for this configuration type." 
LocalCallableValidatorParams: properties: validation_function: title: Validation Function description: Function (Callable[[pd.DataFrame], pd.DataFrame]) to validate the data output_schema: type: object additionalProperties: true title: Output Schema description: Expected schema for local callable validator's output additionalProperties: false type: object required: - validation_function title: LocalCallableValidatorParams description: "Configuration for local callable validation. Expects a function\ \ to be passed that validates the data.\n\nAttributes:\n validation_function:\ \ Function (`Callable[[pd.DataFrame], pd.DataFrame]`) to validate the\n \ \ data. Output must contain a column `is_valid` of type `bool`.\n \ \ output_schema: The JSON schema for the local callable validator's output.\ \ If not provided,\n the output will not be validated." ManualDistribution: properties: distribution_type: allOf: - $ref: '#/components/schemas/DistributionType' default: manual params: $ref: '#/components/schemas/ManualDistributionParams' additionalProperties: false type: object required: - params title: ManualDistribution ManualDistributionParams: properties: values: items: type: number type: array minItems: 1 title: Values weights: type: array items: type: number title: Weights additionalProperties: false type: object required: - values title: ManualDistributionParams MessageType: type: string enum: - analysis - dataset - heartbeat - log title: MessageType Modality: type: string enum: - image title: Modality ModalityDataType: type: string enum: - url - base64 title: ModalityDataType ModelConfigInput: properties: alias: type: string title: Alias model: type: string title: Model inference_parameters: $ref: '#/components/schemas/InferenceParametersInput' provider: type: string title: Provider additionalProperties: false type: object required: - alias - model title: ModelConfigInput ModelConfigOutput: properties: alias: type: string title: Alias model: type: string title: Model 
inference_parameters: $ref: '#/components/schemas/InferenceParametersOutput' provider: type: string title: Provider additionalProperties: false type: object required: - alias - model title: ModelConfigOutput PaginationData: properties: page: type: integer title: Page description: The current page number. page_size: type: integer title: Page Size description: The page size used for the query. current_page_size: type: integer title: Current Page Size description: The size for the current page. total_pages: type: integer title: Total Pages description: The total number of pages. total_results: type: integer title: Total Results description: The total number of results. type: object required: - page - page_size - current_page_size - total_pages - total_results title: PaginationData PartitionBlock: properties: index: type: integer minimum: 0.0 title: Index description: The index of the partition to sample from default: 0 num_partitions: type: integer minimum: 1.0 title: Num Partitions description: The total number of partitions in the dataset default: 1 additionalProperties: false type: object title: PartitionBlock PersonFromFakerSamplerParams: properties: locale: type: string title: Locale description: Locale string, determines the language and geographic locale that a synthetic person will be sampled from. E.g., en_US, en_GB, fr_FR, ... default: en_US sex: type: string title: Sex description: If specified, then only synthetic people of the specified sex will be sampled. city: anyOf: - type: string - items: type: string type: array title: City description: If specified, then only synthetic people from these cities will be sampled. age_range: items: type: integer type: array maxItems: 2 minItems: 2 title: Age Range description: If specified, then only synthetic people within this age range will be sampled. 
default: - 18 - 114 sampler_type: type: string const: person_from_faker title: Sampler Type default: person_from_faker additionalProperties: false type: object title: PersonFromFakerSamplerParams PersonSamplerParams: properties: locale: type: string title: Locale description: 'Locale that determines the language and geographic location that a synthetic person will be sampled from. Must be a locale supported by a managed Nemotron Personas dataset. Managed datasets exist for the following locales: en_US, ja_JP, en_IN, hi_IN.' default: en_US sex: type: string title: Sex description: If specified, then only synthetic people of the specified sex will be sampled. city: anyOf: - type: string - items: type: string type: array title: City description: If specified, then only synthetic people from these cities will be sampled. age_range: items: type: integer type: array maxItems: 2 minItems: 2 title: Age Range description: If specified, then only synthetic people within this age range will be sampled. default: - 18 - 114 select_field_values: type: object additionalProperties: items: type: string type: array title: Select Field Values description: Sample synthetic people with the specified field values. This is meant to be a flexible argument for selecting a subset of the population from the managed dataset. Note that this sampler does not support rare combinations of field values and will likely fail if your desired subset is not well-represented in the managed Nemotron Personas dataset. We generally recommend using the `sex`, `city`, and `age_range` arguments to filter the population when possible. examples: - education_level: - high_school - some_college - bachelors state: - NY - CA - OH - TX - NV with_synthetic_personas: type: boolean title: With Synthetic Personas description: If True, then append synthetic persona columns to each generated person. 
default: false sampler_type: type: string const: person title: Sampler Type default: person additionalProperties: false type: object title: PersonSamplerParams description: "Parameters for sampling synthetic person data with demographic\ \ attributes.\n\nGenerates realistic synthetic person data including names,\ \ addresses, phone numbers, and other\ndemographic information. Data can be\ \ sampled from managed datasets (when available) or generated\nusing Faker.\ \ The sampler supports filtering by locale, sex, age, geographic location,\ \ and can\noptionally include synthetic persona descriptions.\n\nAttributes:\n\ \ locale: Locale string determining the language and geographic region\ \ for synthetic people.\n Format: language_COUNTRY (e.g., \"en_US\"\ , \"en_GB\", \"fr_FR\", \"de_DE\", \"es_ES\", \"ja_JP\").\n Defaults\ \ to \"en_US\".\n sex: If specified, filters to only sample people of the\ \ specified sex. Options: \"Male\" or\n \"Female\". If None, samples\ \ both sexes.\n city: If specified, filters to only sample people from\ \ the specified city or cities. Can be\n a single city name (string)\ \ or a list of city names.\n age_range: Two-element list [min_age, max_age]\ \ specifying the age range to sample from\n (inclusive). Defaults to\ \ a standard age range. Both values must be between minimum and\n maximum\ \ allowed ages.\n with_synthetic_personas: If True, appends additional\ \ synthetic persona columns including\n personality traits, interests,\ \ and background descriptions. Only supported for certain\n locales\ \ with managed datasets.\n sample_dataset_when_available: If True, samples\ \ from curated managed datasets when available\n for the specified\ \ locale. If False or unavailable, falls back to Faker-generated data.\n \ \ Managed datasets typically provide more realistic and diverse synthetic\ \ people." PlatformJobListResultResponse: properties: object: type: string title: Object description: The type of object being returned. 
default: list data: items: $ref: '#/components/schemas/PlatformJobResultResponse' type: array title: Data type: object required: - data title: PlatformJobListResultResponse PlatformJobLog: properties: timestamp: type: string format: date-time title: Timestamp job_id: type: string title: Job Id job_step: type: string title: Job Step job_task: type: string title: Job Task message: type: string title: Message type: object required: - timestamp - job_id - job_step - job_task - message title: PlatformJobLog PlatformJobLogPage: properties: object: type: string title: Object description: The type of object being returned. default: list data: items: $ref: '#/components/schemas/PlatformJobLog' type: array title: Data total: type: integer title: Total next_page: type: string title: Next Page prev_page: type: string title: Prev Page type: object required: - data - total - next_page - prev_page title: PlatformJobLogPage PlatformJobResultResponse: properties: result_name: type: string title: Result Name job_id: type: string title: Job Id namespace: type: string title: Namespace project: type: string title: Project created_at: type: string format: date-time title: Created At updated_at: type: string format: date-time title: Updated At artifact_url: type: string title: Artifact Url artifact_storage_type: $ref: '#/components/schemas/FileStorageType' type: object required: - result_name - job_id - namespace - artifact_url - artifact_storage_type title: PlatformJobResultResponse PlatformJobStatus: type: string enum: - created - pending - active - cancelled - cancelling - error - completed - paused - pausing - resuming title: PlatformJobStatus description: 'Enumeration of possible job statuses. This enum represents the various states a job can be in during its lifecycle, from creation to a terminal state.' 
PlatformJobStatusResponse: properties: job_id: type: string title: Job Id status: $ref: '#/components/schemas/PlatformJobStatus' status_details: additionalProperties: true type: object title: Status Details error_details: type: object additionalProperties: true title: Error Details steps: items: $ref: '#/components/schemas/PlatformJobStepStatusResponse' type: array title: Steps type: object required: - job_id - status - status_details - error_details - steps title: PlatformJobStatusResponse PlatformJobStepStatusResponse: properties: name: type: string title: Name status: $ref: '#/components/schemas/PlatformJobStatus' status_details: additionalProperties: true type: object title: Status Details error_details: type: object additionalProperties: true title: Error Details tasks: items: $ref: '#/components/schemas/PlatformJobTaskStatusResponse' type: array title: Tasks type: object required: - name - status - status_details - error_details - tasks title: PlatformJobStepStatusResponse PlatformJobTaskStatusResponse: properties: id: type: string title: Id status: $ref: '#/components/schemas/PlatformJobStatus' status_details: additionalProperties: true type: object title: Status Details error_details: type: object additionalProperties: true title: Error Details error_stack: type: string title: Error Stack type: object required: - id - status - status_details - error_details - error_stack title: PlatformJobTaskStatusResponse PoissonSamplerParams: properties: mean: type: number title: Mean description: Mean number of events in a fixed interval. sampler_type: type: string const: poisson title: Sampler Type default: poisson additionalProperties: false type: object required: - mean title: PoissonSamplerParams description: "Parameters for sampling from a Poisson distribution.\n\nSamples\ \ non-negative integer values representing the number of events occurring\ \ in a fixed\ninterval of time or space. 
The Poisson distribution is commonly\ \ used to model count data\nlike the number of arrivals, occurrences, or events\ \ per time period.\n\nThe distribution is characterized by a single parameter\ \ (mean/rate), and both the mean and\nvariance equal this parameter value.\n\ \nAttributes:\n mean: Mean number of events in the fixed interval (also\ \ called rate parameter \u03BB).\n Must be positive. This represents\ \ both the expected value and the variance of the\n distribution." PreviewMessage: properties: message: type: string title: Message message_type: $ref: '#/components/schemas/MessageType' extra: type: object additionalProperties: type: string title: Extra additionalProperties: false type: object required: - message - message_type title: PreviewMessage PreviewRequest: properties: config: $ref: '#/components/schemas/DataDesignerConfig' num_records: type: integer title: Num Records type: object required: - config title: PreviewRequest ProcessorConfig: properties: build_stage: allOf: - $ref: '#/components/schemas/BuildStage' description: 'The stage at which the processor will run. 
Supported stages: post_batch' additionalProperties: false type: object required: - build_stage title: ProcessorConfig RemoteValidatorParams: properties: endpoint_url: type: string title: Endpoint Url description: URL of the remote endpoint output_schema: type: object additionalProperties: true title: Output Schema description: Expected schema for remote validator's output timeout: type: number exclusiveMinimum: 0.0 title: Timeout description: The timeout for the HTTP request default: 30.0 max_retries: type: integer minimum: 0.0 title: Max Retries description: The maximum number of retry attempts default: 3 retry_backoff: type: number exclusiveMinimum: 1.0 title: Retry Backoff description: The backoff factor for the retry delay default: 2.0 max_parallel_requests: type: integer minimum: 1.0 title: Max Parallel Requests description: The maximum number of parallel requests to make default: 4 additionalProperties: false type: object required: - endpoint_url title: RemoteValidatorParams description: "Configuration for remote validation. Sends data to a remote endpoint\ \ for validation.\n\nAttributes:\n endpoint_url: The URL of the remote\ \ endpoint.\n output_schema: The JSON schema for the remote validator's\ \ output. If not provided,\n the output will not be validated.\n \ \ timeout: The timeout for the HTTP request in seconds. Defaults to 30.0.\n\ \ max_retries: The maximum number of retry attempts. Defaults to 3.\n \ \ retry_backoff: The backoff factor for the retry delay in seconds. Defaults\ \ to 2.0.\n max_parallel_requests: The maximum number of parallel requests\ \ to make. Defaults to 4." 
SamplerColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: sampler title: Column Type default: sampler sampler_type: $ref: '#/components/schemas/SamplerType' params: oneOf: - $ref: '#/components/schemas/SubcategorySamplerParams' - $ref: '#/components/schemas/CategorySamplerParams' - $ref: '#/components/schemas/DatetimeSamplerParams' - $ref: '#/components/schemas/PersonSamplerParams' - $ref: '#/components/schemas/PersonFromFakerSamplerParams' - $ref: '#/components/schemas/TimeDeltaSamplerParams' - $ref: '#/components/schemas/UUIDSamplerParams' - $ref: '#/components/schemas/BernoulliSamplerParams' - $ref: '#/components/schemas/BernoulliMixtureSamplerParams' - $ref: '#/components/schemas/BinomialSamplerParams' - $ref: '#/components/schemas/GaussianSamplerParams' - $ref: '#/components/schemas/PoissonSamplerParams' - $ref: '#/components/schemas/UniformSamplerParams' - $ref: '#/components/schemas/ScipySamplerParams' title: Params discriminator: propertyName: sampler_type mapping: bernoulli: '#/components/schemas/BernoulliSamplerParams' bernoulli_mixture: '#/components/schemas/BernoulliMixtureSamplerParams' binomial: '#/components/schemas/BinomialSamplerParams' category: '#/components/schemas/CategorySamplerParams' datetime: '#/components/schemas/DatetimeSamplerParams' gaussian: '#/components/schemas/GaussianSamplerParams' person: '#/components/schemas/PersonSamplerParams' person_from_faker: '#/components/schemas/PersonFromFakerSamplerParams' poisson: '#/components/schemas/PoissonSamplerParams' scipy: '#/components/schemas/ScipySamplerParams' subcategory: '#/components/schemas/SubcategorySamplerParams' timedelta: '#/components/schemas/TimeDeltaSamplerParams' uniform: '#/components/schemas/UniformSamplerParams' uuid: '#/components/schemas/UUIDSamplerParams' conditional_params: additionalProperties: oneOf: - $ref: '#/components/schemas/SubcategorySamplerParams' - $ref: 
'#/components/schemas/CategorySamplerParams' - $ref: '#/components/schemas/DatetimeSamplerParams' - $ref: '#/components/schemas/PersonSamplerParams' - $ref: '#/components/schemas/PersonFromFakerSamplerParams' - $ref: '#/components/schemas/TimeDeltaSamplerParams' - $ref: '#/components/schemas/UUIDSamplerParams' - $ref: '#/components/schemas/BernoulliSamplerParams' - $ref: '#/components/schemas/BernoulliMixtureSamplerParams' - $ref: '#/components/schemas/BinomialSamplerParams' - $ref: '#/components/schemas/GaussianSamplerParams' - $ref: '#/components/schemas/PoissonSamplerParams' - $ref: '#/components/schemas/UniformSamplerParams' - $ref: '#/components/schemas/ScipySamplerParams' discriminator: propertyName: sampler_type mapping: bernoulli: '#/components/schemas/BernoulliSamplerParams' bernoulli_mixture: '#/components/schemas/BernoulliMixtureSamplerParams' binomial: '#/components/schemas/BinomialSamplerParams' category: '#/components/schemas/CategorySamplerParams' datetime: '#/components/schemas/DatetimeSamplerParams' gaussian: '#/components/schemas/GaussianSamplerParams' person: '#/components/schemas/PersonSamplerParams' person_from_faker: '#/components/schemas/PersonFromFakerSamplerParams' poisson: '#/components/schemas/PoissonSamplerParams' scipy: '#/components/schemas/ScipySamplerParams' subcategory: '#/components/schemas/SubcategorySamplerParams' timedelta: '#/components/schemas/TimeDeltaSamplerParams' uniform: '#/components/schemas/UniformSamplerParams' uuid: '#/components/schemas/UUIDSamplerParams' type: object title: Conditional Params default: {} convert_to: type: string title: Convert To additionalProperties: false type: object required: - name - sampler_type - params title: SamplerColumnConfig description: "Configuration for columns generated using numerical samplers.\n\ \nSampler columns provide efficient data generation using numerical samplers\ \ for\ncommon data types and distributions. 
Supported samplers include UUID\ \ generation,\ndatetime/timedelta sampling, person generation, category /\ \ subcategory sampling,\nand various statistical distributions (uniform, gaussian,\ \ binomial, poisson, scipy).\n\nAttributes:\n sampler_type: Type of sampler\ \ to use. Available types include:\n \"uuid\", \"category\", \"subcategory\"\ , \"uniform\", \"gaussian\", \"bernoulli\",\n \"bernoulli_mixture\"\ , \"binomial\", \"poisson\", \"scipy\", \"person\", \"datetime\", \"timedelta\"\ .\n params: Parameters specific to the chosen sampler type. Type varies\ \ based on the `sampler_type`\n (e.g., `CategorySamplerParams`, `UniformSamplerParams`,\ \ `PersonSamplerParams`).\n conditional_params: Optional dictionary for\ \ conditional parameters. The dict keys\n are the conditions that must\ \ be met (e.g., \"age > 21\") for the conditional parameters\n to be\ \ used. The values of dict are the parameters to use when the condition is\ \ met.\n convert_to: Optional type conversion to apply after sampling.\ \ Must be one of \"float\", \"int\", or \"str\".\n Useful for converting\ \ numerical samples to strings or other types.\n column_type: Discriminator\ \ field, always \"sampler\" for this configuration type.\n\n!!! 
tip \"Displaying\ \ available samplers and their parameters\"\n The config builder has an\ \ `info` attribute that can be used to display the\n available samplers\ \ and their parameters:\n ```python\n config_builder.info.display(\"\ samplers\")\n ```" SamplerType: type: string enum: - bernoulli - bernoulli_mixture - binomial - category - datetime - gaussian - person - person_from_faker - poisson - scipy - subcategory - timedelta - uniform - uuid title: SamplerType SamplingStrategy: type: string enum: - ordered - shuffle title: SamplingStrategy ScalarInequalityConstraint: properties: target_column: type: string title: Target Column rhs: type: number title: Rhs operator: $ref: '#/components/schemas/InequalityOperator' additionalProperties: false type: object required: - target_column - rhs - operator title: ScalarInequalityConstraint ScipySamplerParams: properties: dist_name: type: string title: Dist Name description: Name of a scipy.stats distribution. dist_params: additionalProperties: true type: object title: Dist Params description: Parameters of the scipy.stats distribution given in `dist_name`. decimal_places: type: integer title: Decimal Places description: Number of decimal places to round the sampled values to. sampler_type: type: string const: scipy title: Sampler Type default: scipy additionalProperties: false type: object required: - dist_name - dist_params title: ScipySamplerParams description: "Parameters for sampling from any scipy.stats continuous or discrete\ \ distribution.\n\nProvides a flexible interface to sample from the wide range\ \ of probability distributions\navailable in scipy.stats. This enables advanced\ \ statistical sampling beyond the built-in\ndistribution types (Gaussian,\ \ Uniform, etc.).\n\nSee: [scipy.stats documentation](https://docs.scipy.org/doc/scipy/reference/stats.html)\n\ \nAttributes:\n dist_name: Name of the scipy.stats distribution to sample\ \ from (e.g., \"beta\", \"gamma\",\n \"lognorm\", \"expon\"). 
Must\ \ be a valid distribution name from scipy.stats.\n dist_params: Dictionary\ \ of parameters for the specified distribution. Parameter names\n and\ \ values must match the scipy.stats distribution specification (e.g., {\"\ a\": 2, \"b\": 5}\n for beta distribution, {\"scale\": 1.5} for exponential).\n\ \ decimal_places: Optional number of decimal places to round sampled values\ \ to. If None,\n values are not rounded." Score: properties: name: type: string title: Name description: A clear name for this score. description: type: string title: Description description: An informative and detailed assessment guide for using this score. options: additionalProperties: type: string type: object title: Options description: 'Score options in the format of {score: description}.' additionalProperties: false type: object required: - name - description - options title: Score description: "Configuration for a \"score\" in an LLM judge evaluation.\n\n\ Defines a single scoring criterion with its possible values and descriptions.\ \ Multiple\nScore objects can be combined in an LLMJudgeColumnConfig to create\ \ multi-dimensional\nquality assessments.\n\nAttributes:\n name: A clear,\ \ concise name for this scoring dimension (e.g., \"Relevance\", \"Fluency\"\ ).\n description: An informative and detailed assessment guide explaining\ \ how to evaluate\n this dimension. Should provide clear criteria for\ \ scoring.\n options: Dictionary mapping score values to their descriptions.\ \ Keys can be integers\n (e.g., 1-5 scale) or strings (e.g., \"Poor\"\ , \"Good\", \"Excellent\"). Values are\n descriptions explaining what\ \ each score level means." 
SeedConfig: properties: dataset: type: string title: Dataset sampling_strategy: allOf: - $ref: '#/components/schemas/SamplingStrategy' default: ordered selection_strategy: anyOf: - $ref: '#/components/schemas/IndexRange' - $ref: '#/components/schemas/PartitionBlock' title: Selection Strategy additionalProperties: false type: object required: - dataset title: SeedConfig description: "Configuration for sampling data from a seed dataset.\n\nArgs:\n\ \ dataset: Path or identifier for the seed dataset.\n sampling_strategy:\ \ Strategy for how to sample rows from the dataset.\n - ORDERED: Read\ \ rows sequentially in their original order.\n - SHUFFLE: Randomly\ \ shuffle rows before sampling. When used with\n selection_strategy,\ \ shuffling occurs within the selected range/partition.\n selection_strategy:\ \ Optional strategy to select a subset of the dataset.\n - IndexRange:\ \ Select a specific range of indices (e.g., rows 100-200).\n - PartitionBlock:\ \ Select a partition by splitting the dataset into N equal parts.\n \ \ Partition indices are zero-based (index=0 is the first partition, index=1\ \ is\n the second, etc.).\n\nExamples:\n Read rows sequentially\ \ from start to end:\n SeedConfig(dataset=\"my_data.parquet\", sampling_strategy=SamplingStrategy.ORDERED)\n\ \n Read rows in random order:\n SeedConfig(dataset=\"my_data.parquet\"\ , sampling_strategy=SamplingStrategy.SHUFFLE)\n\n Read specific index range\ \ (rows 100-199):\n SeedConfig(\n dataset=\"my_data.parquet\"\ ,\n sampling_strategy=SamplingStrategy.ORDERED,\n selection_strategy=IndexRange(start=100,\ \ end=199)\n )\n\n Read random rows from a specific index range\ \ (shuffles within rows 100-199):\n SeedConfig(\n dataset=\"\ my_data.parquet\",\n sampling_strategy=SamplingStrategy.SHUFFLE,\n\ \ selection_strategy=IndexRange(start=100, end=199)\n )\n\ \n Read from partition 2 (3rd partition, zero-based) of 5 partitions (20%\ \ of dataset):\n SeedConfig(\n dataset=\"my_data.parquet\"\ ,\n 
sampling_strategy=SamplingStrategy.ORDERED,\n selection_strategy=PartitionBlock(index=2,\ \ num_partitions=5)\n )\n\n Read shuffled rows from partition 0\ \ of 10 partitions (shuffles within the partition):\n SeedConfig(\n\ \ dataset=\"my_data.parquet\",\n sampling_strategy=SamplingStrategy.SHUFFLE,\n\ \ selection_strategy=PartitionBlock(index=0, num_partitions=10)\n\ \ )" SeedDatasetColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: seed-dataset title: Column Type default: seed-dataset additionalProperties: false type: object required: - name title: SeedDatasetColumnConfig description: "Configuration for columns sourced from seed datasets.\n\nThis\ \ config marks columns that come from seed data. It is typically created\n\ automatically when calling `with_seed_dataset()` on the builder, rather than\n\ being instantiated directly by users.\n\nAttributes:\n column_type: Discriminator\ \ field, always \"seed-dataset\" for this configuration type." SettingsDefaults: properties: model_configs: items: $ref: '#/components/schemas/ModelConfigOutput' type: array title: Model Configs model_provider: type: string title: Model Provider type: object required: - model_configs - model_provider title: SettingsDefaults SettingsResponse: properties: defaults: $ref: '#/components/schemas/SettingsDefaults' model_providers: items: $ref: '#/components/schemas/DisplayModelProvider' type: array title: Model Providers type: object required: - defaults - model_providers title: SettingsResponse SubcategorySamplerParams: properties: category: type: string title: Category description: Name of parent category to this subcategory. values: additionalProperties: items: anyOf: - type: string - type: integer - type: number type: array type: object title: Values description: Mapping from each value of parent category to a list of subcategory values. 
sampler_type: type: string const: subcategory title: Sampler Type default: subcategory additionalProperties: false type: object required: - category - values title: SubcategorySamplerParams description: "Parameters for subcategory sampling conditioned on a parent category\ \ column.\n\nSamples subcategory values based on the value of a parent category\ \ column. Each parent\ncategory value maps to its own list of possible subcategory\ \ values, enabling hierarchical\nor conditional sampling patterns.\n\nAttributes:\n\ \ category: Name of the parent category column that this subcategory depends\ \ on.\n The parent column must be generated before this subcategory\ \ column.\n values: Mapping from each parent category value to a list of\ \ possible subcategory values.\n Each key must correspond to a value\ \ that appears in the parent category column." TimeDeltaSamplerParams: properties: dt_min: type: integer minimum: 0.0 title: Dt Min description: Minimum possible time-delta for sampling range, inclusive. Must be less than `dt_max`. dt_max: type: integer exclusiveMinimum: 0.0 title: Dt Max description: Maximum possible time-delta for sampling range, exclusive. Must be greater than `dt_min`. reference_column_name: type: string title: Reference Column Name description: Name of an existing datetime column to condition time-delta sampling on. unit: type: string enum: - D - h - m - s title: Unit description: Sampling units, e.g. the smallest possible time interval between samples. default: D sampler_type: type: string const: timedelta title: Sampler Type default: timedelta additionalProperties: false type: object required: - dt_min - dt_max - reference_column_name title: TimeDeltaSamplerParams description: "Parameters for sampling time deltas relative to a reference datetime\ \ column.\n\nSamples time offsets within a specified range and adds them to\ \ values from a reference\ndatetime column. 
This is useful for generating\ \ related datetime columns like order dates\nand delivery dates, or event\ \ start times and end times.\n\nNote:\n Years and months are not supported\ \ as timedelta units because they have variable lengths.\n See: [pandas\ \ timedelta documentation](https://pandas.pydata.org/docs/user_guide/timedeltas.html)\n\ \nAttributes:\n dt_min: Minimum time-delta value (inclusive). Must be non-negative\ \ and less than `dt_max`.\n Specified in units defined by the `unit`\ \ parameter.\n dt_max: Maximum time-delta value (exclusive). Must be positive\ \ and greater than `dt_min`.\n Specified in units defined by the `unit`\ \ parameter.\n reference_column_name: Name of an existing datetime column\ \ to add the time-delta to.\n This column must be generated before\ \ the timedelta column.\n unit: Time unit for the delta values. Options:\n\ \ - \"D\": Days (default)\n - \"h\": Hours\n - \"m\"\ : Minutes\n - \"s\": Seconds" UUIDSamplerParams: properties: prefix: type: string title: Prefix description: String prepended to the front of the UUID. short_form: type: boolean title: Short Form description: If true, all UUIDs sampled will be truncated at 8 characters. default: false uppercase: type: boolean title: Uppercase description: If true, all letters in the UUID will be capitalized. default: false sampler_type: type: string const: uuid title: Sampler Type default: uuid additionalProperties: false type: object title: UUIDSamplerParams description: "Parameters for generating UUID (Universally Unique Identifier)\ \ values.\n\nGenerates UUID4 (random) identifiers with optional formatting\ \ options. UUIDs are useful\nfor creating unique identifiers for records,\ \ entities, or transactions.\n\nAttributes:\n prefix: Optional string to\ \ prepend to each UUID. Useful for creating namespaced or\n typed identifiers\ \ (e.g., \"user-\", \"order-\", \"txn-\").\n short_form: If True, truncates\ \ UUIDs to 8 characters (first segment only). 
Default is False\n for\ \ full 32-character UUIDs (excluding hyphens).\n uppercase: If True, converts\ \ all hexadecimal letters to uppercase. Default is False for\n lowercase\ \ UUIDs." UniformDistribution: properties: distribution_type: allOf: - $ref: '#/components/schemas/DistributionType' default: uniform params: $ref: '#/components/schemas/UniformDistributionParams' additionalProperties: false type: object required: - params title: UniformDistribution UniformDistributionParams: properties: low: type: number title: Low high: type: number title: High additionalProperties: false type: object required: - low - high title: UniformDistributionParams UniformSamplerParams: properties: low: type: number title: Low description: Lower bound of the uniform distribution, inclusive. high: type: number title: High description: Upper bound of the uniform distribution, inclusive. decimal_places: type: integer title: Decimal Places description: Number of decimal places to round the sampled values to. sampler_type: type: string const: uniform title: Sampler Type default: uniform additionalProperties: false type: object required: - low - high title: UniformSamplerParams description: "Parameters for sampling from a continuous Uniform distribution.\n\ \nSamples continuous values uniformly from a specified range, where every\ \ value in the range\nhas equal probability of being sampled. This is useful\ \ when all values within a range are\nequally likely, such as random percentages,\ \ proportions, or unbiased measurements.\n\nAttributes:\n low: Lower bound\ \ of the uniform distribution (inclusive). Can be any real number.\n high:\ \ Upper bound of the uniform distribution (inclusive). Must be greater than\ \ `low`.\n decimal_places: Optional number of decimal places to round sampled\ \ values to. If None,\n values are not rounded and may have many decimal\ \ places." 
ValidationColumnConfig: properties: name: type: string title: Name drop: type: boolean title: Drop default: false column_type: type: string const: validation title: Column Type default: validation target_columns: items: type: string type: array title: Target Columns validator_type: $ref: '#/components/schemas/ValidatorType' validator_params: anyOf: - $ref: '#/components/schemas/CodeValidatorParams' - $ref: '#/components/schemas/LocalCallableValidatorParams' - $ref: '#/components/schemas/RemoteValidatorParams' title: Validator Params batch_size: type: integer minimum: 1.0 title: Batch Size description: Number of records to process in each batch default: 10 additionalProperties: false type: object required: - name - target_columns - validator_type - validator_params title: ValidationColumnConfig description: "Configuration for validation columns that validate existing columns.\n\ \nValidation columns execute validation logic against specified target columns\ \ and return\nstructured results indicating pass/fail status with validation\ \ details. Supports multiple\nvalidation strategies: code execution (Python/SQL),\ \ local callable functions (library only),\nand remote HTTP endpoints.\n\n\ Attributes:\n target_columns: List of column names to validate. These columns\ \ are passed to the\n validator for validation. All target columns\ \ must exist in the dataset\n before validation runs.\n validator_type:\ \ The type of validator to use. Options:\n - \"code\": Execute code\ \ (Python or SQL) for validation. The code receives a\n DataFrame\ \ with target columns and must return a DataFrame with validation results.\n\ \ - \"local_callable\": Call a local Python function with the data.\ \ Only supported\n when running DataDesigner locally.\n -\ \ \"remote\": Send data to a remote HTTP endpoint for validation. Useful for\n\ \ validator_params: Parameters specific to the validator type. 
Type varies\ \ by validator:\n - CodeValidatorParams: Specifies code language (python\ \ or SQL dialect like\n \"sql:postgres\", \"sql:mysql\").\n \ \ - LocalCallableValidatorParams: Provides validation function (Callable[[pd.DataFrame],\n\ \ pd.DataFrame]) and optional output schema for validation results.\n\ \ - RemoteValidatorParams: Configures endpoint URL, HTTP timeout, retry\ \ behavior\n (max_retries, retry_backoff), and parallel request limits\ \ (max_parallel_requests).\n batch_size: Number of records to process in\ \ each validation batch. Defaults to 10.\n Larger batches are more\ \ efficient but use more memory. Adjust based on validator\n complexity\ \ and available resources.\n column_type: Discriminator field, always \"\ validation\" for this configuration type." ValidationError: properties: loc: items: anyOf: - type: string - type: integer type: array title: Location msg: type: string title: Message type: type: string title: Error Type type: object required: - loc - msg - type title: ValidationError ValidatorType: type: string enum: - code - local_callable - remote title: ValidatorType tags: - name: Data Designer description: Operations related to synthetic data generation. - name: Health Checks description: Operations related to NeMo Microservices platform health. ================================================ FILE: studio/frontend/eslint.config.js ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 import js from "@eslint/js"; import reactHooks from "eslint-plugin-react-hooks"; import reactRefresh from "eslint-plugin-react-refresh"; import { defineConfig, globalIgnores } from "eslint/config"; import globals from "globals"; import tseslint from "typescript-eslint"; export default defineConfig([ globalIgnores(["dist", "**/._*"]), { files: ["**/*.{ts,tsx}"], extends: [ js.configs.recommended, tseslint.configs.recommended, reactHooks.configs.flat.recommended, reactRefresh.configs.vite, ], languageOptions: { ecmaVersion: 2020, globals: globals.browser, }, rules: { // Allow shadcn ui components to export variants "react-refresh/only-export-components": [ "warn", { allowConstantExport: true }, ], // Import restrictions for architecture enforcement "no-restricted-imports": [ "error", { patterns: [ // Prevent cross-feature imports { group: ["@/features/*/*"], message: "Import from feature index only: @/features/[name]", }, // Prevent app layer from importing features internals { group: ["../features/*/**"], message: "Use absolute imports: @/features/[name]", }, ], }, ], }, }, ]); ================================================ FILE: studio/frontend/index.html ================================================ Unsloth Studio
================================================ FILE: studio/frontend/package.json ================================================ { "name": "unsloth-theme", "private": true, "version": "0.0.0", "type": "module", "scripts": { "dev": "vite", "build": "tsc -b && vite build", "lint": "eslint .", "preview": "vite preview", "typecheck": "tsc -b --pretty false", "biome:check": "biome check .", "biome:fix": "biome check . --write" }, "dependencies": { "@assistant-ui/react": "^0.12.19", "@assistant-ui/react-markdown": "^0.12.3", "@assistant-ui/react-streamdown": "^0.1.2", "@base-ui/react": "^1.2.0", "@dagrejs/dagre": "^2.0.4", "@dagrejs/graphlib": "^3.0.4", "@fontsource-variable/figtree": "^5.2.10", "@fontsource-variable/inter": "^5.2.8", "@fontsource-variable/space-grotesk": "^5.2.10", "@hugeicons/core-free-icons": "^3.1.1", "@hugeicons/react": "^1.1.5", "@huggingface/hub": "^2.9.0", "@langchain/core": "^1.1.27", "@radix-ui/react-checkbox": "^1.3.3", "@radix-ui/react-label": "^2.1.8", "@radix-ui/react-select": "^2.2.6", "@radix-ui/react-separator": "^1.1.8", "@radix-ui/react-slot": "^1.2.4", "@streamdown/cjk": "1.0.2", "@streamdown/code": "1.0.2", "@streamdown/math": "1.0.2", "@streamdown/mermaid": "1.0.2", "@tailwindcss/vite": "^4.1.18", "@tanstack/react-router": "^1.159.10", "@tanstack/react-table": "^8.21.3", "@toolwind/corner-shape": "^0.0.8-3", "@types/canvas-confetti": "^1.9.0", "@xyflow/react": "^12.10.0", "assistant-stream": "^0.3.2", "canvas-confetti": "^1.9.4", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "cmdk": "^1.1.1", "date-fns": "^4.1.0", "dexie": "^4.3.0", "framer-motion": "^11.18.2", "js-yaml": "^4.1.1", "katex": "^0.16.28", "lucide-react": "^0.577.0", "mammoth": "^1.11.0", "motion": "^12.34.0", "next": "^16.1.6", "next-themes": "^0.4.6", "radix-ui": "^1.4.3", "react": "^19.2.4", "react-day-picker": "^9.13.2", "react-dom": "^19.2.4", "react-resizable-panels": "^4.6.4", "recharts": "3.7.0", "remark-gfm": "^4.0.1", "shadcn": "^3.8.4", 
"sonner": "^2.0.7", "streamdown": "2.3.0", "tailwind-merge": "^3.4.0", "tailwindcss": "^4.1.18", "tw-animate-css": "^1.4.0", "tw-shimmer": "^0.4.6", "unpdf": "^1.4.0", "zustand": "^5.0.11" }, "devDependencies": { "@biomejs/biome": "^1.9.4", "@eslint/js": "^9.39.1", "@types/js-yaml": "^4.0.9", "@types/node": "^24.10.1", "@types/react": "^19.2.5", "@types/react-dom": "^19.2.3", "@vitejs/plugin-react": "^5.1.1", "eslint": "^9.39.1", "eslint-plugin-react-hooks": "^7.0.1", "eslint-plugin-react-refresh": "^0.4.26", "globals": "^16.5.0", "typescript": "~5.9.3", "typescript-eslint": "^8.55.0", "vite": "^7.3.1" } } ================================================ FILE: studio/frontend/src/app/app.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { RouterProvider } from "@tanstack/react-router"; import { router } from "./router"; export function App() { return ; } ================================================ FILE: studio/frontend/src/app/auth-guards.ts ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 import { redirect } from "@tanstack/react-router"; import { getPostAuthRoute, hasAuthToken, hasRefreshToken, mustChangePassword, refreshSession, } from "@/features/auth"; async function hasActiveSession(): Promise { if (hasAuthToken()) return true; if (!hasRefreshToken()) return false; return refreshSession(); } async function checkAuthInitialized(): Promise { try { const res = await fetch("/api/auth/status"); if (!res.ok) return true; // fallback to login on error const data = (await res.json()) as { initialized: boolean }; return data.initialized; } catch { return true; // fallback to login on error } } async function checkPasswordChangeRequired(): Promise { try { const res = await fetch("/api/auth/status"); if (!res.ok) return mustChangePassword(); const data = (await res.json()) as { requires_password_change: boolean }; return data.requires_password_change || mustChangePassword(); } catch { return mustChangePassword(); } } export async function requireAuth(): Promise { if (await hasActiveSession()) { if (await checkPasswordChangeRequired()) { throw redirect({ to: "/change-password" }); } return; } const requiresPasswordChange = await checkPasswordChangeRequired(); if (requiresPasswordChange) throw redirect({ to: "/change-password" }); const initialized = await checkAuthInitialized(); throw redirect({ to: initialized ? "/login" : "/change-password" }); } export async function requireGuest(): Promise { if (!(await hasActiveSession())) return; throw redirect({ to: getPostAuthRoute() }); } export async function requirePasswordChangeFlow(): Promise { const requiresPasswordChange = await checkPasswordChangeRequired(); if (requiresPasswordChange) return; if (await hasActiveSession()) { throw redirect({ to: getPostAuthRoute() }); } const initialized = await checkAuthInitialized(); throw redirect({ to: initialized ? 
"/login" : "/change-password" }); } ================================================ FILE: studio/frontend/src/app/provider.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { Toaster } from "@/components/ui/sonner"; import { ThemeProvider } from "next-themes"; import type { ReactNode } from "react"; interface AppProviderProps { children: ReactNode; } export function AppProvider({ children }: AppProviderProps) { return ( {children} ); } ================================================ FILE: studio/frontend/src/app/router.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { createRouter } from "@tanstack/react-router"; import { Route as rootRoute } from "./routes/__root"; import { Route as dataRecipesRoute } from "./routes/data-recipes"; import { Route as dataRecipeRoute } from "./routes/data-recipes.$recipeId"; import { Route as chatRoute } from "./routes/chat"; import { Route as exportRoute } from "./routes/export"; import { Route as gridTestRoute } from "./routes/grid-test"; import { Route as indexRoute } from "./routes/index"; import { Route as loginRoute } from "./routes/login"; import { Route as onboardingRoute } from "./routes/onboarding"; import { Route as changePasswordRoute } from "./routes/change-password"; import { Route as studioRoute } from "./routes/studio"; const routeTree = rootRoute.addChildren([ indexRoute, onboardingRoute, loginRoute, changePasswordRoute, gridTestRoute, studioRoute, chatRoute, exportRoute, dataRecipesRoute, dataRecipeRoute, ]); export const router = createRouter({ routeTree }); declare module "@tanstack/react-router" { interface Register { router: typeof router; } } ================================================ FILE: 
studio/frontend/src/app/routes/__root.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { Navbar } from "@/components/navbar"; import { usePlatformStore } from "@/config/env"; import { Outlet, createRootRoute, redirect, useRouterState, } from "@tanstack/react-router"; import { AnimatePresence, motion } from "motion/react"; import { Suspense } from "react"; import { AppProvider } from "../provider"; const CHAT_ONLY_ALLOWED = new Set(["/", "/chat", "/login", "/signup", "/change-password"]); function isChatOnlyAllowed(pathname: string): boolean { if (CHAT_ONLY_ALLOWED.has(pathname)) return true; if (pathname === "/data-recipes" || pathname.startsWith("/data-recipes/")) return true; return false; } export const Route = createRootRoute({ beforeLoad: ({ location }) => { const chatOnly = usePlatformStore.getState().isChatOnly(); if (chatOnly && !isChatOnlyAllowed(location.pathname)) { throw redirect({ to: "/chat" }); } }, component: RootLayout, }); const HIDDEN_NAVBAR_ROUTES = ["/onboarding", "/login", "/change-password"]; function RootLayout() { const pathname = useRouterState({ select: (s) => s.location.pathname }); const hideNavbar = HIDDEN_NAVBAR_ROUTES.includes(pathname); return ( {!hideNavbar && } ); } ================================================ FILE: studio/frontend/src/app/routes/change-password.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 import { createRoute } from "@tanstack/react-router"; import { lazy } from "react"; import { requirePasswordChangeFlow } from "../auth-guards"; import { Route as rootRoute } from "./__root"; const ChangePasswordPage = lazy(() => import("@/features/auth").then((m) => ({ default: m.ChangePasswordPage, })), ); export const Route = createRoute({ getParentRoute: () => rootRoute, path: "/change-password", beforeLoad: () => requirePasswordChangeFlow(), component: ChangePasswordPage, }); ================================================ FILE: studio/frontend/src/app/routes/chat.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { createRoute } from "@tanstack/react-router"; import { lazy } from "react"; import { requireAuth } from "../auth-guards"; import { Route as rootRoute } from "./__root"; const ChatPage = lazy(() => import("@/features/chat/chat-page").then((m) => ({ default: m.ChatPage })), ); export const Route = createRoute({ getParentRoute: () => rootRoute, path: "/chat", beforeLoad: () => requireAuth(), component: ChatPage, }); ================================================ FILE: studio/frontend/src/app/routes/data-recipes.$recipeId.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0 import { createRoute } from "@tanstack/react-router"; import type { ReactElement } from "react"; import { lazy } from "react"; import { requireAuth } from "../auth-guards"; import { Route as rootRoute } from "./__root"; const EditRecipePage = lazy(() => import("@/features/data-recipes").then((m) => ({ default: m.EditRecipePage, })), ); export const Route = createRoute({ getParentRoute: () => rootRoute, path: "/data-recipes/$recipeId", beforeLoad: () => requireAuth(), component: DataRecipeEditorRoute, }); function DataRecipeEditorRoute(): ReactElement { const { recipeId } = Route.useParams(); return ; } ================================================ FILE: studio/frontend/src/app/routes/data-recipes.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0 import { createRoute } from "@tanstack/react-router"; import { lazy } from "react"; import { requireAuth } from "../auth-guards"; import { Route as rootRoute } from "./__root"; const DataRecipesPage = lazy(() => import("@/features/data-recipes").then((m) => ({ default: m.DataRecipesPage, })), ); export const Route = createRoute({ getParentRoute: () => rootRoute, path: "/data-recipes", beforeLoad: () => requireAuth(), component: DataRecipesPage, }); ================================================ FILE: studio/frontend/src/app/routes/export.tsx ================================================ // SPDX-License-Identifier: AGPL-3.0-only // Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. 
See /studio/LICENSE.AGPL-3.0
import { createRoute } from "@tanstack/react-router";
import { lazy } from "react";
import { requireAuth } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// Lazily loaded `ExportPage` from `@/features/export/export-page`.
const ExportPage = lazy(() =>
  import("@/features/export/export-page").then((m) => ({
    default: m.ExportPage,
  })),
);

// `/export` child of the root route, guarded by `requireAuth()`.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/export",
  beforeLoad: () => requireAuth(),
  component: ExportPage,
});
================================================
FILE: studio/frontend/src/app/routes/grid-test.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import { DashboardGrid, DashboardLayout } from "@/components/layout";
import {
  Card,
  CardContent,
  CardDescription,
  CardHeader,
  CardTitle,
} from "@/components/ui/card";
import { createRoute } from "@tanstack/react-router";
import { requireAuth } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// `/grid-test` child of the root route, guarded by `requireAuth()`.
// NOTE(review): judging by its content this looks like a layout test page for
// grid column counts; confirm it is meant to be reachable in production builds.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/grid-test",
  beforeLoad: () => requireAuth(),
  component: GridTestPage,
});

// Static page rendering a 3-card grid and a 4-card grid with sizing captions.
// NOTE(review): the JSX markup of this component was stripped in this copy and
// only its text content remains below; compare with the original file.
function GridTestPage() {
  return (

Grid Test - 3 Columns

max-w-7xl, gap-6, responsive 1→2→3

{[1, 2, 3].map((i) => ( Card {i} ~400px at 1280px viewport
))}

4 Columns

~296px per card at 1280px

{[1, 2, 3, 4].map((i) => ( Card {i} Smaller cards
))}
  );
}
================================================
FILE: studio/frontend/src/app/routes/index.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import { createRoute, redirect } from "@tanstack/react-router";
import { getPostAuthRoute } from "@/features/auth";
import { requireAuth } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// `/` has no page of its own. Once `requireAuth()` resolves, `beforeLoad`
// unconditionally throws a redirect to `getPostAuthRoute()`, so `component`
// simply renders nothing.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/",
  beforeLoad: async () => {
    await requireAuth();
    throw redirect({ to: getPostAuthRoute() });
  },
  component: () => null,
});
================================================
FILE: studio/frontend/src/app/routes/login.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import { createRoute } from "@tanstack/react-router";
import { lazy } from "react";
import { requireGuest } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// Lazily loaded `LoginPage` from `@/features/auth`.
const LoginPage = lazy(() =>
  import("@/features/auth").then((m) => ({ default: m.LoginPage })),
);

// `/login` child of the root route; guarded by `requireGuest()` rather than
// `requireAuth()`.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/login",
  beforeLoad: () => requireGuest(),
  component: LoginPage,
});
================================================
FILE: studio/frontend/src/app/routes/onboarding.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
See /studio/LICENSE.AGPL-3.0
import { createRoute } from "@tanstack/react-router";
import { lazy } from "react";
import { requireAuth } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// Code-split onboarding wizard: `lazy` defers loading the wizard-layout module
// until this component first renders, then uses its `WizardLayout` export.
const WizardLayout = lazy(async () => {
  const wizardModule = await import(
    "@/features/onboarding/components/wizard-layout"
  );
  return { default: wizardModule.WizardLayout };
});

// `/onboarding` child of the root route, guarded by `requireAuth()`.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/onboarding",
  beforeLoad: () => requireAuth(),
  component: WizardLayout,
});
================================================
FILE: studio/frontend/src/app/routes/studio.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
import { createRoute } from "@tanstack/react-router";
import { lazy } from "react";
import { requireAuth } from "../auth-guards";
import { Route as rootRoute } from "./__root";

// Code-split studio page, resolved from `@/features/studio/studio-page` the
// first time it renders.
const StudioPage = lazy(async () => {
  const studioModule = await import("@/features/studio/studio-page");
  return { default: studioModule.StudioPage };
});

// `/studio` child of the root route, guarded by `requireAuth()`.
export const Route = createRoute({
  getParentRoute: () => rootRoute,
  path: "/studio",
  beforeLoad: () => requireAuth(),
  component: StudioPage,
});
================================================
FILE: studio/frontend/src/components/assistant-ui/attachment.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved.
See /studio/LICENSE.AGPL-3.0
"use client";
// Avatar removed — caused circular crop on image thumbnails
import { TooltipIconButton } from "@/components/assistant-ui/tooltip-icon-button";
import {
  Dialog,
  DialogContent,
  DialogTitle,
  DialogTrigger,
} from "@/components/ui/dialog";
import {
  Tooltip,
  TooltipContent,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
import {
  AttachmentPrimitive,
  ComposerPrimitive,
  MessagePrimitive,
  useAui,
  useAuiState,
} from "@assistant-ui/react";
import { FileText, PlusIcon, XIcon } from "lucide-react";
import {
  type FC,
  type PropsWithChildren,
  useEffect,
  useState,
} from "react";
import { useShallow } from "zustand/shallow";

// NOTE(review): in this extracted copy (this file and the audio-player and
// badge files that follow), JSX markup and several `<...>` type arguments
// appear to have been stripped. Examples: `FC = `, `useState(undefined)`,
// `useRef(null)`, `React.ChangeEvent)`, `VariantProps &`, and `return ( ... )`
// bodies that hold only text fragments. The text is not compilable as-is;
// compare with the original files before relying on it.

// Returns an object URL for `file`, or undefined when there is no file. The
// URL is created inside an effect and revoked by that effect's cleanup, which
// runs when `file` changes or the component unmounts.
const useFileSrc = (file: File | undefined): string | undefined => {
  const [objectUrl, setObjectUrl] = useState(undefined);
  useEffect(() => {
    if (!file) {
      setObjectUrl(undefined);
      return;
    }
    const url = URL.createObjectURL(file);
    setObjectUrl(url);
    return () => URL.revokeObjectURL(url);
  }, [file]);
  return objectUrl;
};

// Resolves a displayable source for the current attachment:
// - a local `attachment.file` gives an object URL;
// - otherwise the `image` field of the first image part in `attachment.content` is used;
// - non-image attachments yield undefined.
// `useShallow` keeps the previous selector result when the new one is
// shallow-equal, so a freshly built `{}` / `{ file }` / `{ src }` with the same
// fields is not treated as a change.
const useAttachmentSrc = (): string | undefined => {
  const { file, src } = useAuiState(
    useShallow(({ attachment }): { file?: File; src?: string } => {
      if (attachment.type !== "image") {
        return {};
      }
      if (attachment.file) {
        return { file: attachment.file };
      }
      const src = attachment.content?.filter((c) => c.type === "image")[0]
        ?.image;
      if (!src) {
        return {};
      }
      return { src };
    }),
  );
  return useFileSrc(file) ?? src;
};

type AttachmentPreviewProps = {
  src: string;
};

// Full-size image preview for `src`. `isLoaded` flips to true from a callback
// on the (stripped) preview element, presumably its `onLoad`.
const AttachmentPreview: FC = ({ src }) => {
  const [isLoaded, setIsLoaded] = useState(false);
  return ( Preview setIsLoaded(true)} /> );
};

// Without an image source, returns `children` unwrapped. Otherwise it
// presumably wraps them in the imported `Dialog`/`DialogTrigger`, with an
// "Image Attachment Preview" title in the dialog (markup stripped here).
const AttachmentPreviewDialog: FC = ({ children }) => {
  const src = useAttachmentSrc();
  if (!src) {
    return children;
  }
  return ( {children} Image Attachment Preview
  );
};

// Tile thumbnail: one branch when an image source exists, a fallback branch
// otherwise (markup stripped here; `FileText` is imported, presumably for the
// fallback).
const AttachmentThumb: FC = () => {
  const src = useAttachmentSrc();
  if (src) {
    return ( Attachment preview );
  }
  return (
  );
};

// One attachment tile. `isComposer` compares `aui.attachment.source` with
// "composer". `typeLabel` maps the attachment type to a display label and
// throws (inside the selector) for any type other than
// "image" | "document" | "file".
const AttachmentUI: FC = () => {
  const aui = useAui();
  const isComposer = aui.attachment.source === "composer";
  const isImage = useAuiState(({ attachment }) => attachment.type === "image");
  const typeLabel = useAuiState(({ attachment }) => {
    const type = attachment.type;
    switch (type) {
      case "image":
        return "Image";
      case "document":
        return "Document";
      case "file":
        return "File";
      default:
        throw new Error(`Unknown attachment type: ${type as string}`);
    }
  });
  return ( #attachment-tile]:size-16", )} > {isComposer && } );
};

// Attachment remove control; its markup is stripped in this copy.
const AttachmentRemove: FC = () => {
  return ( );
};

// Presumably lists a user message's attachments via `MessagePrimitive`
// (markup stripped in this copy).
export const UserMessageAttachments: FC = () => {
  return (
  );
};

// Presumably lists the composer's pending attachments via `ComposerPrimitive`
// (markup stripped in this copy).
export const ComposerAttachments: FC = () => {
  return (
  );
};

// Add-attachment control for the composer (markup stripped in this copy;
// `PlusIcon` is imported, presumably for it).
export const ComposerAddAttachment: FC = () => {
  return ( );
};
================================================
FILE: studio/frontend/src/components/assistant-ui/audio-player.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team. All rights reserved. See /studio/LICENSE.AGPL-3.0
"use client";
import { Button } from "@/components/ui/button";
import { DownloadIcon, PauseIcon, PlayIcon } from "lucide-react";
import { type FC, useRef, useState } from "react";

interface AudioPlayerProps {
  src: string;
}

// State and handlers for an audio player of `src`: play/pause toggle, seek,
// elapsed/total time and download. Playback values are copied into React
// state from the element held in `audioRef`.
export const AudioPlayer: FC = ({ src }) => {
  const audioRef = useRef(null);
  const [isPlaying, setIsPlaying] = useState(false);
  const [progress, setProgress] = useState(0);
  const [duration, setDuration] = useState(0);

  // Toggles playback and flips `isPlaying`.
  // NOTE(review): `HTMLMediaElement.play()` returns a Promise that can reject
  // (e.g. autoplay policy, unsupported source). Here it is neither awaited nor
  // caught, and `isPlaying` is flipped regardless of the outcome, so a failed
  // play leaves the UI in the "playing" state. Consider setting the state
  // from `audio.play().then(...)` and handling `.catch(...)`.
  const togglePlay = () => {
    const audio = audioRef.current;
    if (!audio) return;
    if (isPlaying) {
      audio.pause();
    } else {
      audio.play();
    }
    setIsPlaying(!isPlaying);
  };

  // Copies the element's `currentTime` into `progress`.
  const handleTimeUpdate = () => {
    const audio = audioRef.current;
    if (!audio) return;
    setProgress(audio.currentTime);
  };

  // Copies `audio.duration` into `duration` (presumably wired to the element's
  // `loadedmetadata` event in the stripped markup).
  const handleLoadedMetadata = () => {
    const audio = audioRef.current;
    if (!audio) return;
    setDuration(audio.duration);
  };

  // Resets the UI to the start: not playing, progress 0.
  const handleEnded = () => {
    setIsPlaying(false);
    setProgress(0);
  };

  // Seeks the element to the input's numeric value (seconds) and updates
  // `progress` immediately rather than waiting for the next time update.
  const handleSeek = (e: React.ChangeEvent) => {
    const audio = audioRef.current;
    if (!audio) return;
    const time = parseFloat(e.target.value);
    audio.currentTime = time;
    setProgress(time);
  };

  // Downloads `src` by clicking a temporary, detached anchor element.
  // NOTE(review): the filename is always "generated-audio.wav", whatever the
  // actual format behind `src`.
  const handleDownload = () => {
    const link = document.createElement("a");
    link.href = src;
    link.download = "generated-audio.wav";
    link.click();
  };

  // Formats seconds as m:ss; minutes are not rolled over into hours.
  // NOTE(review): a non-finite `t` (NaN, or an `Infinity` duration) renders as
  // "NaN:NaN" / "Infinity:NaN"; confirm inputs are always finite.
  const formatTime = (t: number) => {
    const mins = Math.floor(t / 60);
    const secs = Math.floor(t % 60);
    return `${mins}:${secs.toString().padStart(2, "0")}`;
  };

  return (
  );
};
================================================
FILE: studio/frontend/src/components/assistant-ui/badge.tsx
================================================
// NOTE(review): unlike its neighbours this file carries no SPDX/copyright
// header; confirm whether that is intentional.
"use client";
import type { ComponentProps } from "react";
import { Slot } from "radix-ui";
import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";

// Class-name generator for badges: shared base classes plus one `variant`
// (color scheme) and one `size` (padding / text size). Defaults are
// `outline` and `default`.
const badgeVariants = cva(
  "inline-flex items-center justify-center gap-1 rounded-md font-medium text-xs transition-colors [&_svg]:size-3 [&_svg]:shrink-0",
  {
    variants: {
      variant: {
        outline:
          "border border-input bg-transparent text-muted-foreground hover:bg-accent hover:text-accent-foreground",
        secondary:
          "bg-secondary text-secondary-foreground hover:bg-secondary/80",
        muted:
          "bg-muted text-muted-foreground hover:bg-muted/80 hover:text-foreground",
        ghost:
          "bg-transparent text-muted-foreground hover:bg-accent hover:text-accent-foreground",
        info: "bg-blue-100 text-blue-700 hover:bg-blue-100/80 dark:bg-blue-900/50 dark:text-blue-300",
        warning:
          "bg-amber-100 text-amber-700 hover:bg-amber-100/80 dark:bg-amber-900/50 dark:text-amber-300",
        success:
          "bg-emerald-100 text-emerald-700 hover:bg-emerald-100/80 dark:bg-emerald-900/50 dark:text-emerald-300",
        destructive:
          "bg-red-100 text-red-700 hover:bg-red-100/80 dark:bg-red-900/50 dark:text-red-300",
      },
      size: {
        sm: "px-1.5 py-0.5",
        default: "px-2 py-1",
        lg: "px-2.5 py-1.5 text-sm",
      },
    },
    defaultVariants: {
      variant: "outline",
      size: "default",
    },
  },
);

export type BadgeProps = ComponentProps<"span"> &
  VariantProps & {
    asChild?: boolean;
  };

// Inline badge. `Comp` is a plain "span", or Radix `Slot.Root` when `asChild`
// is set (Slot merges its props onto its child element). The returned markup
// is stripped in this copy.
function Badge({
  className,
  variant,
  size,
  asChild = false,
  ...props
}: BadgeProps) {
  const Comp = asChild ? Slot.Root : "span";
  return ( );
}

export { Badge, badgeVariants };
================================================
FILE: studio/frontend/src/components/assistant-ui/markdown-text.tsx
================================================
// SPDX-License-Identifier: AGPL-3.0-only
// Copyright 2026-present the Unsloth AI Inc. team.
All rights reserved. See /studio/LICENSE.AGPL-3.0 "use client"; import { copyToClipboard } from "@/lib/copy-to-clipboard"; import { INTERNAL, useMessagePartText } from "@assistant-ui/react"; import { Copy02Icon, Tick02Icon } from "@hugeicons/core-free-icons"; import { HugeiconsIcon } from "@hugeicons/react"; import { code } from "@streamdown/code"; import { math } from "@streamdown/math"; import { mermaid } from "@streamdown/mermaid"; import { DownloadIcon, Maximize2Icon, Minimize2Icon } from "lucide-react"; import { useEffect, useRef, useState } from "react"; import { Block, type BlockProps, Streamdown } from "streamdown"; import "katex/dist/katex.min.css"; import { AudioPlayer } from "./audio-player"; const { withSmoothContextProvider } = INTERNAL; const COPY_RESET_MS = 2000; const MERMAID_SOURCE_RE = /```mermaid\s*([\s\S]*?)```/i; const CODE_FENCE_RE = /^```([^\r\n`]*)\r?\n([\s\S]*?)\r?\n?```$/; const ACTION_PANEL_CLASS = "pointer-events-auto flex shrink-0 items-center gap-2 rounded-md border border-sidebar bg-sidebar/80 px-1.5 py-1 supports-[backdrop-filter]:bg-sidebar/70 supports-[backdrop-filter]:backdrop-blur"; const ACTION_BUTTON_CLASS = "cursor-pointer p-1 text-muted-foreground transition-all hover:text-foreground disabled:cursor-not-allowed disabled:opacity-50"; type CodeFence = { language: string | null; source: string; }; function getMermaidSource(blockContent: string): string | null { const source = blockContent.match(MERMAID_SOURCE_RE)?.[1]?.trim(); return source && source.length > 0 ? 
source : null; } function getCodeFence(blockContent: string): CodeFence | null { const match = blockContent.trimEnd().match(CODE_FENCE_RE); if (!match) { return null; } return { language: match[1]?.trim() || null, source: match[2], }; } function getCodeFilename(language: string | null) { const extByLanguage: Record = { bash: "sh", javascript: "js", js: "js", json: "json", jsx: "jsx", markdown: "md", md: "md", python: "py", py: "py", shell: "sh", sh: "sh", sql: "sql", ts: "ts", tsx: "tsx", typescript: "ts", svg: "svg", yaml: "yml", yml: "yml", }; const normalized = language?.toLowerCase(); const fallbackExt = normalized?.replace(/[^a-z0-9]+/g, "-"); const ext = normalized ? extByLanguage[normalized] || fallbackExt || "txt" : "txt"; return `snippet.${ext}`; } function isSvgFence(codeFence: CodeFence): boolean { const lang = codeFence.language?.toLowerCase() ?? ""; if (lang === "svg") return true; if ((lang === "xml" || lang === "html") && codeFence.source.trimStart().startsWith("]|on\w+\s*=|javascript:|]|]|]|]/i; function sanitizeSvg(source: string): string | null { if (UNSAFE_SVG_RE.test(source)) return null; return source; } function SvgPreview({ source }: { source: string }) { const dataUri = `data:image/svg+xml;charset=utf-8,${encodeURIComponent(source)}`; return (
SVG preview
); } const HTML_PREVIEW_DEFAULT_HEIGHT = 400; const HTML_PREVIEW_MAX_HEIGHT = 800; function HtmlPreview({ source }: { source: string }) { const iframeRef = useRef(null); const [height, setHeight] = useState(HTML_PREVIEW_DEFAULT_HEIGHT); const [enlarged, setEnlarged] = useState(false); useEffect(() => { const handler = (e: MessageEvent) => { if (e.source !== iframeRef.current?.contentWindow) return; if (typeof e.data?.htmlPreviewHeight === "number") { setHeight(Math.min(Math.max(e.data.htmlPreviewHeight, 100), HTML_PREVIEW_MAX_HEIGHT)); } }; window.addEventListener("message", handler); return () => window.removeEventListener("message", handler); }, []); useEffect(() => { if (!enlarged) return; const handler = (e: KeyboardEvent) => { if (e.key === "Escape") setEnlarged(false); }; window.addEventListener("keydown", handler); return () => window.removeEventListener("keydown", handler); }, [enlarged]); const resizeScript = ``; const srcDoc = source + resizeScript; if (enlarged) { return ( <>
{/* Placeholder keeps layout stable while overlay is shown */}
{ if (e.target === e.currentTarget) setEnlarged(false); }} >