Repository: oobabooga/text-generation-webui Branch: main Commit: 256431f25869 Files: 118 Total size: 1013.0 KB Directory structure: gitextract_rxb2e4y6/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report_template.yml │ │ └── feature_request.md │ ├── dependabot.yml │ ├── pull_request_template.md │ └── workflows/ │ ├── build-everything-tgw.yml │ ├── build-portable-release-cuda.yml │ ├── build-portable-release-rocm.yml │ ├── build-portable-release-vulkan.yml │ └── build-portable-release.yml ├── .gitignore ├── LICENSE ├── README.md ├── cmd_linux.sh ├── cmd_macos.sh ├── cmd_windows.bat ├── docker/ │ ├── .dockerignore │ ├── TensorRT-LLM/ │ │ └── Dockerfile │ ├── amd/ │ │ ├── Dockerfile │ │ └── docker-compose.yml │ ├── cpu/ │ │ ├── Dockerfile │ │ └── docker-compose.yml │ ├── intel/ │ │ ├── Dockerfile │ │ └── docker-compose.yml │ └── nvidia/ │ ├── Dockerfile │ └── docker-compose.yml ├── docs/ │ ├── 01 - Chat Tab.md │ ├── 02 - Default and Notebook Tabs.md │ ├── 03 - Parameters Tab.md │ ├── 04 - Model Tab.md │ ├── 05 - Training Tab.md │ ├── 06 - Session Tab.md │ ├── 07 - Extensions.md │ ├── 08 - Additional Tips.md │ ├── 09 - Docker.md │ ├── 11 - AMD Setup.md │ ├── 12 - OpenAI API.md │ ├── 13 - Keyboard Shortcuts.md │ ├── Image Generation Tutorial.md │ ├── Multimodal Tutorial.md │ ├── README.md │ ├── Tool Calling Tutorial.md │ └── What Works.md ├── download-model.py ├── js/ │ ├── dark_theme.js │ ├── global_scope_js.js │ ├── katex/ │ │ └── auto-render.js │ ├── main.js │ ├── save_files.js │ ├── show_controls.js │ ├── switch_tabs.js │ └── update_big_picture.js ├── modules/ │ ├── LoRA.py │ ├── callbacks.py │ ├── chat.py │ ├── evaluate.py │ ├── exllamav3.py │ ├── exllamav3_hf.py │ ├── extensions.py │ ├── grammar/ │ │ ├── grammar_utils.py │ │ └── logits_process.py │ ├── html_generator.py │ ├── image_models.py │ ├── image_utils.py │ ├── llama_cpp_server.py │ ├── loaders.py │ ├── logging_colors.py │ ├── logits.py │ ├── metadata_gguf.py │ ├── models.py │ ├── models_settings.py │ 
├── paths.py │ ├── presets.py │ ├── prompts.py │ ├── reasoning.py │ ├── sampler_hijack.py │ ├── sane_markdown_lists.py │ ├── shared.py │ ├── tensorrt_llm.py │ ├── text_generation.py │ ├── tool_parsing.py │ ├── tool_use.py │ ├── torch_utils.py │ ├── training.py │ ├── transformers_loader.py │ ├── ui.py │ ├── ui_chat.py │ ├── ui_default.py │ ├── ui_file_saving.py │ ├── ui_image_generation.py │ ├── ui_model_menu.py │ ├── ui_notebook.py │ ├── ui_parameters.py │ ├── ui_session.py │ ├── utils.py │ └── web_search.py ├── one_click.py ├── requirements/ │ ├── full/ │ │ ├── requirements.txt │ │ ├── requirements_amd.txt │ │ ├── requirements_apple_intel.txt │ │ ├── requirements_apple_silicon.txt │ │ ├── requirements_cpu_only.txt │ │ └── requirements_nowheels.txt │ └── portable/ │ ├── requirements.txt │ ├── requirements_amd.txt │ ├── requirements_apple_intel.txt │ ├── requirements_apple_silicon.txt │ ├── requirements_cpu_only.txt │ ├── requirements_cuda131.txt │ ├── requirements_nowheels.txt │ └── requirements_vulkan.txt ├── server.py ├── setup.cfg ├── start_linux.sh ├── start_macos.sh ├── start_windows.bat ├── update_wizard_linux.sh ├── update_wizard_macos.sh └── update_wizard_windows.bat ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report_template.yml ================================================ # GitHub issue form for bug reports: each entry under "body" is one element of the rendered form. name: "Bug report" description: Report a bug labels: [ "bug" ] body: - type: markdown attributes: value: | Thanks for taking the time to fill out this bug report! - type: textarea id: bug-description attributes: label: Describe the bug description: A clear and concise description of what the bug is. placeholder: Bug description validations: required: true - type: checkboxes attributes: label: Is there an existing issue for this? description: Please search to see if an issue already exists for the issue you encountered.
options: - label: I have searched the existing issues required: true - type: textarea id: reproduction attributes: label: Reproduction description: Please provide the steps necessary to reproduce your issue. placeholder: Reproduction validations: required: true - type: textarea id: screenshot attributes: label: Screenshot description: "If possible, please include screenshot(s) so that we can understand what the issue is." - type: textarea id: logs attributes: label: Logs description: "Please include the full stacktrace of the errors you get in the command-line (if any)." render: shell validations: required: true - type: textarea id: system-info attributes: label: System Info description: "Please share your operating system and GPU type (NVIDIA/AMD/Intel/Apple). If you are using a Google Colab notebook, mention that instead." render: shell validations: required: true ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an improvement or new feature for the web UI title: '' labels: 'enhancement' assignees: '' --- **Description** A clear and concise description of what you want to be implemented. **Additional Context** If applicable, please provide any extra information, external links, or screenshots that could be useful. ================================================ FILE: .github/dependabot.yml ================================================ # To get started with Dependabot version updates, you'll need to specify which # package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options: # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates version: 2 updates: - package-ecosystem: "pip" directories: - "/requirements/full/" - "/requirements/portable/" target-branch: "dev" schedule: interval: "weekly" ================================================ FILE: .github/pull_request_template.md ================================================ ## Checklist: - [ ] I have read the [Contributing guidelines](https://github.com/oobabooga/text-generation-webui/wiki/Contributing-guidelines). ================================================ FILE: .github/workflows/build-everything-tgw.yml ================================================ name: Build Everything TGW on: workflow_dispatch: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string permissions: contents: write jobs: build_release_cuda_windows: name: CUDA Windows uses: ./.github/workflows/build-portable-release-cuda.yml with: version: ${{ inputs.version }} config: 'os:windows-2022' build_release_cuda_linux: name: CUDA Linux uses: ./.github/workflows/build-portable-release-cuda.yml with: version: ${{ inputs.version }} config: 'os:ubuntu-22.04' build_release_vulkan_windows: name: Vulkan Windows uses: ./.github/workflows/build-portable-release-vulkan.yml with: version: ${{ inputs.version }} config: 'os:windows-2022' build_release_vulkan_linux: name: Vulkan Linux uses: ./.github/workflows/build-portable-release-vulkan.yml with: version: ${{ inputs.version }} config: 'os:ubuntu-22.04' build_release_rocm_linux: name: ROCm Linux uses: ./.github/workflows/build-portable-release-rocm.yml with: version: ${{ inputs.version }} config: 'os:ubuntu-22.04' build_release_cpu_windows: name: CPU Windows uses: ./.github/workflows/build-portable-release.yml with: version: ${{ inputs.version }} config: 'os:windows-2022' build_release_cpu_linux:
name: CPU Linux uses: ./.github/workflows/build-portable-release.yml with: version: ${{ inputs.version }} config: 'os:ubuntu-22.04' build_release_macos: name: macOS uses: ./.github/workflows/build-portable-release.yml with: version: ${{ inputs.version }} config: 'os:macos-15-intel,macos-14' ================================================ FILE: .github/workflows/build-portable-release-cuda.yml ================================================ name: Build CUDA on: workflow_dispatch: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string workflow_call: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string permissions: contents: write jobs: define_matrix: name: Define Build Matrix runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} defaults: run: shell: pwsh env: CONFIGIN: ${{ inputs.config }} EXCLUDEIN: ${{ inputs.exclude }} steps: - name: Define Job Output id: set-matrix run: | $matrix = @{ 'os' = @('ubuntu-22.04', 'windows-2022') 'pyver' = @("3.13") 'cuda' = @("12.4", "13.1") } if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:EXCLUDEIN -ne 'None') { $exclusions = @() $exclusions +=
$env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData $matrix['exclude'] = $exclusions } $matrixOut = ConvertTo-Json $matrix -Compress Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT build_wheels: name: ${{ matrix.os }} ${{ matrix.pyver }} CUDA ${{ matrix.cuda }} needs: define_matrix runs-on: ${{ matrix.os }} strategy: matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }} defaults: run: shell: pwsh env: PCKGVER: ${{ inputs.version }} steps: - uses: actions/checkout@v6 with: repository: 'oobabooga/text-generation-webui' ref: ${{ inputs.version }} submodules: 'recursive' - uses: actions/setup-python@v6 with: python-version: ${{ matrix.pyver }} - name: Build Package shell: bash run: | # Read the version from the job-level PCKGVER env var (set from the version input) so raw input text is never spliced into the script VERSION_CLEAN="$PCKGVER" VERSION_CLEAN="${VERSION_CLEAN#v}" cd .. cp -r text-generation-webui "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}" # Remove extensions that need additional requirements allowed=("character_bias" "gallery" "openai" "sd_api_pictures") find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf # Define common variables CUDA_VERSION="${{ matrix.cuda }}" # 1. Set platform-specific variables if [[ "$RUNNER_OS" == "Windows" ]]; then PLATFORM="windows" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PIP_PATH="portable_env/python.exe -m pip" PACKAGES_PATH="portable_env/Lib/site-packages" rm start_linux.sh start_macos.sh else PLATFORM="linux" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PIP_PATH="portable_env/bin/python -m pip" PACKAGES_PATH="portable_env/lib/python3.13/site-packages" rm start_macos.sh start_windows.bat fi # 2.
Download and extract Python cd .. echo "Downloading Python for $PLATFORM..." # -f: exit non-zero on HTTP errors instead of saving the error page as the archive curl -fL -o python-build.tar.gz "$PYTHON_URL" tar -xzf python-build.tar.gz mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" # 3. Prepare requirements file based on CUDA version cd "text-generation-webui-${VERSION_CLEAN}" if [[ "$CUDA_VERSION" == "13.1" ]]; then REQ_FILE="requirements/portable/requirements_cuda131.txt" else REQ_FILE="requirements/portable/requirements.txt" fi # 4. Install packages echo "Installing Python packages from $REQ_FILE..." $PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE" # 5. Clean up rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py # 6. Create archive cd .. if [[ "$RUNNER_OS" == "Windows" ]]; then ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.zip" echo "Creating archive: $ARCHIVE_NAME" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" else ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-cuda${CUDA_VERSION}.tar.gz" echo "Creating archive: $ARCHIVE_NAME" tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}" fi - name: Upload files to a GitHub release id: upload-release uses: svenstaro/upload-release-action@2.7.0 continue-on-error: true with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: ../textgen-portable-* tag: ${{ inputs.version }} file_glob: true make_latest: false overwrite: true ================================================ FILE: .github/workflows/build-portable-release-rocm.yml ================================================ name: Build ROCm on: workflow_dispatch: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string
exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string workflow_call: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string permissions: contents: write jobs: define_matrix: name: Define Build Matrix runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} defaults: run: shell: pwsh env: CONFIGIN: ${{ inputs.config }} EXCLUDEIN: ${{ inputs.exclude }} steps: - name: Define Job Output id: set-matrix run: | $matrix = @{ 'os' = @('ubuntu-22.04', 'windows-2022') 'pyver' = @("3.13") } if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:EXCLUDEIN -ne 'None') { $exclusions = @() $exclusions += $env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData $matrix['exclude'] = $exclusions } $matrixOut = ConvertTo-Json $matrix -Compress Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT build_wheels: name: ${{ matrix.os }} ${{ matrix.pyver }} needs: define_matrix runs-on: ${{ matrix.os }} strategy: matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }} defaults: run: shell: pwsh env: PCKGVER: ${{ inputs.version }} steps: - uses: actions/checkout@v6 with: repository: 'oobabooga/text-generation-webui' ref: ${{ inputs.version }} submodules: 'recursive' - uses: actions/setup-python@v6 with: python-version: ${{ matrix.pyver }} - name: Build Package shell: bash run: | # Read the version from the job-level PCKGVER env var (set from the version input) so raw input text is never spliced into the script VERSION_CLEAN="$PCKGVER" VERSION_CLEAN="${VERSION_CLEAN#v}" cd ..
cp -r text-generation-webui "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}" # Remove extensions that need additional requirements allowed=("character_bias" "gallery" "openai" "sd_api_pictures") find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf # 1. Set platform-specific variables if [[ "$RUNNER_OS" == "Windows" ]]; then PLATFORM="windows" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PIP_PATH="portable_env/python.exe -m pip" PACKAGES_PATH="portable_env/Lib/site-packages" rm start_linux.sh start_macos.sh else PLATFORM="linux" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PIP_PATH="portable_env/bin/python -m pip" PACKAGES_PATH="portable_env/lib/python3.13/site-packages" rm start_macos.sh start_windows.bat fi # 2. Download and extract Python cd .. echo "Downloading Python for $PLATFORM..." # -f: exit non-zero on HTTP errors instead of saving the error page as the archive curl -fL -o python-build.tar.gz "$PYTHON_URL" tar -xzf python-build.tar.gz mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" # 3. Prepare requirements file REQ_FILE="requirements/portable/requirements_amd.txt" cd "text-generation-webui-${VERSION_CLEAN}" # 4. Install packages echo "Installing Python packages from $REQ_FILE..." $PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE" # 5. Clean up rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py # 6. Create archive cd ..
if [[ "$RUNNER_OS" == "Windows" ]]; then ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.zip" echo "Creating archive: $ARCHIVE_NAME" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" else ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-rocm7.2.tar.gz" echo "Creating archive: $ARCHIVE_NAME" tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}" fi - name: Upload files to a GitHub release id: upload-release uses: svenstaro/upload-release-action@2.7.0 continue-on-error: true with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: ../textgen-portable-* tag: ${{ inputs.version }} file_glob: true make_latest: false overwrite: true ================================================ FILE: .github/workflows/build-portable-release-vulkan.yml ================================================ name: Build Vulkan on: workflow_dispatch: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string workflow_call: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string permissions: contents: write jobs: define_matrix: name: Define Build Matrix runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} defaults: run: shell: pwsh env:
CONFIGIN: ${{ inputs.config }} EXCLUDEIN: ${{ inputs.exclude }} steps: - name: Define Job Output id: set-matrix run: | $matrix = @{ 'os' = @('ubuntu-22.04', 'windows-2022') 'pyver' = @("3.13") } if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:EXCLUDEIN -ne 'None') { $exclusions = @() $exclusions += $env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData $matrix['exclude'] = $exclusions } $matrixOut = ConvertTo-Json $matrix -Compress Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT build_wheels: name: ${{ matrix.os }} ${{ matrix.pyver }} needs: define_matrix runs-on: ${{ matrix.os }} strategy: matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }} defaults: run: shell: pwsh env: PCKGVER: ${{ inputs.version }} steps: - uses: actions/checkout@v6 with: repository: 'oobabooga/text-generation-webui' ref: ${{ inputs.version }} submodules: 'recursive' - uses: actions/setup-python@v6 with: python-version: ${{ matrix.pyver }} - name: Build Package shell: bash run: | # Read the version from the job-level PCKGVER env var (set from the version input) so raw input text is never spliced into the script VERSION_CLEAN="$PCKGVER" VERSION_CLEAN="${VERSION_CLEAN#v}" cd .. cp -r text-generation-webui "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}" # Remove extensions that need additional requirements allowed=("character_bias" "gallery" "openai" "sd_api_pictures") find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf # 1.
Set platform-specific variables if [[ "$RUNNER_OS" == "Windows" ]]; then PLATFORM="windows" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" PIP_PATH="portable_env/python.exe -m pip" PACKAGES_PATH="portable_env/Lib/site-packages" rm start_linux.sh start_macos.sh else PLATFORM="linux" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" PIP_PATH="portable_env/bin/python -m pip" PACKAGES_PATH="portable_env/lib/python3.13/site-packages" rm start_macos.sh start_windows.bat fi # 2. Download and extract Python cd .. echo "Downloading Python for $PLATFORM..." # -f: exit non-zero on HTTP errors instead of saving the error page as the archive curl -fL -o python-build.tar.gz "$PYTHON_URL" tar -xzf python-build.tar.gz mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" # 3. Prepare requirements file REQ_FILE="requirements/portable/requirements_vulkan.txt" cd "text-generation-webui-${VERSION_CLEAN}" # 4. Install packages echo "Installing Python packages from $REQ_FILE..." $PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE" # 5. Clean up rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py # 6. Create archive cd ..
if [[ "$RUNNER_OS" == "Windows" ]]; then ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.zip" echo "Creating archive: $ARCHIVE_NAME" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" else ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}-vulkan.tar.gz" echo "Creating archive: $ARCHIVE_NAME" tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}" fi - name: Upload files to a GitHub release id: upload-release uses: svenstaro/upload-release-action@2.7.0 continue-on-error: true with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: ../textgen-portable-* tag: ${{ inputs.version }} file_glob: true make_latest: false overwrite: true ================================================ FILE: .github/workflows/build-portable-release.yml ================================================ name: Build CPU and macOS on: workflow_dispatch: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Override configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string workflow_call: inputs: version: description: 'Version tag of text-generation-webui to build: v3.0' default: 'v3.0' required: true type: string config: description: 'Configurations to build: key1:item1-1,item1-2;key2:item2-1,item2-2' default: 'Default' required: false type: string exclude: description: 'Exclude build configurations: key1-1:item1-1,key1-2:item1-2;key2-1:item2-1,key2-2:item2-2' default: 'None' required: false type: string permissions: contents: write jobs: define_matrix: name: Define Build Matrix runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} defaults: run: shell: pwsh env:
CONFIGIN: ${{ inputs.config }} EXCLUDEIN: ${{ inputs.exclude }} steps: - name: Define Job Output id: set-matrix run: | $matrix = @{ 'os' = @('ubuntu-22.04', 'windows-2022', 'macos-14') 'pyver' = @("3.13") } if ($env:CONFIGIN -ne 'Default') {$env:CONFIGIN.split(';').foreach({$matrix[$_.split(':')[0]] = $_.split(':')[1].split(',')})} if ($env:EXCLUDEIN -ne 'None') { $exclusions = @() $exclusions += $env:EXCLUDEIN.split(';').replace(':','=').replace(',',"`n") | ConvertFrom-StringData $matrix['exclude'] = $exclusions } $matrixOut = ConvertTo-Json $matrix -Compress Write-Output ('matrix=' + $matrixOut) >> $env:GITHUB_OUTPUT build_wheels: name: ${{ matrix.os }} ${{ matrix.pyver }} needs: define_matrix runs-on: ${{ matrix.os }} strategy: matrix: ${{ fromJSON(needs.define_matrix.outputs.matrix) }} defaults: run: shell: pwsh env: PCKGVER: ${{ inputs.version }} steps: - uses: actions/checkout@v6 with: repository: 'oobabooga/text-generation-webui' ref: ${{ inputs.version }} submodules: 'recursive' - uses: actions/setup-python@v6 with: python-version: ${{ matrix.pyver }} - name: Build Package shell: bash run: | # Read the version from the job-level PCKGVER env var (set from the version input) so raw input text is never spliced into the script VERSION_CLEAN="$PCKGVER" VERSION_CLEAN="${VERSION_CLEAN#v}" cd .. cp -r text-generation-webui "text-generation-webui-${VERSION_CLEAN}" cd "text-generation-webui-${VERSION_CLEAN}" # Remove extensions that need additional requirements allowed=("character_bias" "gallery" "openai" "sd_api_pictures") find extensions/ -mindepth 1 -maxdepth 1 -type d | grep -v -E "$(printf '%s|' "${allowed[@]}" | sed 's/|$//')" | xargs rm -rf # Define common variables OS_TYPE="${{ matrix.os }}" # 1.
Set platform-specific variables (REQ_TYPE selects requirements/portable/requirements_<REQ_TYPE>.txt in step 3) if [[ "$RUNNER_OS" == "Windows" ]]; then PLATFORM="windows-cpu" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-pc-windows-msvc-install_only.tar.gz" REQ_TYPE="cpu_only" PIP_PATH="portable_env/python.exe -m pip" PACKAGES_PATH="portable_env/Lib/site-packages" rm start_linux.sh start_macos.sh elif [[ "$RUNNER_OS" == "macOS" ]]; then if [[ "$OS_TYPE" == "macos-15-intel" ]]; then PLATFORM="macos-x86_64" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-apple-darwin-install_only.tar.gz" REQ_TYPE="apple_intel" else PLATFORM="macos-arm64" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-aarch64-apple-darwin-install_only.tar.gz" REQ_TYPE="apple_silicon" fi PIP_PATH="portable_env/bin/python -m pip" PACKAGES_PATH="portable_env/lib/python3.13/site-packages" rm start_linux.sh start_windows.bat else # Linux case PLATFORM="linux-cpu" PYTHON_URL="https://github.com/astral-sh/python-build-standalone/releases/download/20260303/cpython-3.13.12+20260303-x86_64-unknown-linux-gnu-install_only.tar.gz" REQ_TYPE="cpu_only" PIP_PATH="portable_env/bin/python -m pip" PACKAGES_PATH="portable_env/lib/python3.13/site-packages" rm start_macos.sh start_windows.bat fi # 2. Download and extract Python echo "Downloading Python for $PLATFORM..." cd .. # -f: exit non-zero on HTTP errors instead of saving the error page as the archive curl -fL -o python-build.tar.gz "$PYTHON_URL" tar -xzf python-build.tar.gz mv python "text-generation-webui-${VERSION_CLEAN}/portable_env" # 3.
Prepare requirements file based on platform cd "text-generation-webui-${VERSION_CLEAN}" # REQ_TYPE was chosen per platform in step 1 (cpu_only, apple_intel or apple_silicon) REQ_FILE="requirements/portable/requirements_${REQ_TYPE}.txt" echo "Using requirements file: $REQ_FILE" # 4. Install packages echo "Installing Python packages from $REQ_FILE..." $PIP_PATH install --target="./$PACKAGES_PATH" -r "$REQ_FILE" # 5. Clean up rm -rf .git cmd* update_wizard* Colab-TextGen-GPU.ipynb docker setup.cfg .github .gitignore requirements/ one_click.py # 6. Create archive cd .. if [[ "$RUNNER_OS" == "Windows" ]]; then ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.zip" echo "Creating archive: $ARCHIVE_NAME" powershell -Command "Compress-Archive -Path text-generation-webui-${VERSION_CLEAN} -DestinationPath $ARCHIVE_NAME" else ARCHIVE_NAME="textgen-portable-${VERSION_CLEAN}-${PLATFORM}.tar.gz" echo "Creating archive: $ARCHIVE_NAME" tar czf "$ARCHIVE_NAME" "text-generation-webui-${VERSION_CLEAN}" fi - name: Upload files to a GitHub release id: upload-release uses: svenstaro/upload-release-action@2.7.0 continue-on-error: true with: repo_token: ${{ secrets.GITHUB_TOKEN }} file: ../textgen-portable-* tag: ${{ inputs.version }} file_glob: true make_latest: false overwrite: true ================================================ FILE: .gitignore ================================================ /css /extensions /installer_files /repositories /user_data .chroma .DS_Store .eslintrc.js .idea .installer_state.json .venv venv .envrc .direnv .vs .vscode *.bak *.ipynb *.log *pycache* cert.pem key.pem package.json package-lock.json Thumbs.db wandb # ignore user docker config and top level links to docker files /docker-compose.yaml /docker-compose.yml /Dockerfile .env
================================================ FILE: LICENSE ================================================ GNU AFFERO GENERAL PUBLIC LICENSE Version 3, 19 November 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. Preamble The GNU Affero General Public License is a free, copyleft license for software and other kinds of works, specifically designed to ensure cooperation with the community in the case of network server software. The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, our General Public Licenses are intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. Developers that use our General Public Licenses protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License which gives you legal permission to copy, distribute and/or modify the software. A secondary benefit of defending all users' freedom is that improvements made in alternate versions of the program, if they receive widespread use, become available for other developers to incorporate. Many developers of free software are heartened and encouraged by the resulting cooperation. However, in the case of software used on network servers, this result may fail to come about. 
The GNU General Public License permits making a modified version and letting the public access it on a server without ever releasing its source code to the public. The GNU Affero General Public License is designed specifically to ensure that, in such cases, the modified source code becomes available to the community. It requires the operator of a network server to provide the source code of the modified version running there to the users of that server. Therefore, public use of a modified version, on a publicly accessible server, gives the public access to the source code of the modified version. An older license, called the Affero General Public License and published by Affero, was designed to accomplish similar goals. This is a different license, not a version of the Affero GPL, but Affero has released a new version of the Affero GPL which permits relicensing under this license. The precise terms and conditions for copying, distribution and modification follow. TERMS AND CONDITIONS 0. Definitions. "This License" refers to version 3 of the GNU Affero General Public License. "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations. To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work. A "covered work" means either the unmodified Program or a work based on the Program. To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. 
Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. 1. Source Code. The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work. A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. 
The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. The Corresponding Source for a work in source code form is that same work. 2. Basic Permissions. All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. 
Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. 3. Protecting Users' Legal Rights From Anti-Circumvention Law. No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures. 4. Conveying Verbatim Copies. You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. 5. Conveying Modified Source Versions. 
You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: a) The work must carry prominent notices stating that you modified it, and giving a relevant date. b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices". c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. 6. Conveying Non-Source Forms. 
You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. 
Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. 
If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. 7. Additional Terms. "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. 
When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or d) Limiting the use for publicity purposes of names of licensors or authors of the material; or e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. 
If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. 8. Termination. You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. 
If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. 9. Acceptance Not Required for Having Copies. You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. 10. Automatic Licensing of Downstream Recipients. Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. 
For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. 11. Patents. A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version". A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. 
If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. 
You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. 12. No Surrender of Others' Freedom. If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. 13. Remote Network Interaction; Use with the GNU General Public License. 
Notwithstanding any other provision of this License, if you modify the Program, your modified version must prominently offer all users interacting with it remotely through a computer network (if your version supports such interaction) an opportunity to receive the Corresponding Source of your version by providing access to the Corresponding Source from a network server at no charge, through some standard or customary means of facilitating copying of software. This Corresponding Source shall include the Corresponding Source for any work covered by version 3 of the GNU General Public License that is incorporated pursuant to the following paragraph. Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the work with which it is combined will remain governed by version 3 of the GNU General Public License. 14. Revised Versions of this License. The Free Software Foundation may publish revised and/or new versions of the GNU Affero General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU Affero General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU Affero General Public License, you may choose any version ever published by the Free Software Foundation. 
If the Program specifies that a proxy can decide which future versions of the GNU Affero General Public License can be used, that proxy's public statement of acceptance of a version permanently authorizes you to choose that version for the Program. Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. 15. Disclaimer of Warranty. THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 16. Limitation of Liability. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 17. Interpretation of Sections 15 and 16. 
If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. END OF TERMS AND CONDITIONS How to Apply These Terms to Your New Programs If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. <one line to give the program's name and a brief idea of what it does.> Copyright (C) <year> <name of author> This program is free software: you can redistribute it and/or modify it under the terms of the GNU Affero General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. You should have received a copy of the GNU Affero General Public License along with this program. If not, see <https://www.gnu.org/licenses/>. Also add information on how to contact you by electronic and paper mail. If your software can interact with users remotely through a computer network, you should also make sure that it provides a way for users to get its source. For example, if your program is a web application, its interface could display a "Source" link that leads users to an archive of the code.
There are many ways you could offer source, and different solutions will be better for different programs; see section 13 for the specific requirements. You should also get your employer (if you work as a programmer) or school, if any, to sign a "copyright disclaimer" for the program, if necessary. For more information on this, and how to apply and follow the GNU AGPL, see <https://www.gnu.org/licenses/>. ================================================ FILE: README.md ================================================
Special thanks to:

Warp sponsorship ### [Warp, built for coding with multiple AI agents](https://go.warp.dev/text-generation-webui) [Available for macOS, Linux, & Windows](https://go.warp.dev/text-generation-webui)

# Text Generation Web UI A Gradio web UI for running Large Language Models locally. 100% private and offline. Supports text generation, vision, tool-calling, training, image generation, and more. [Try the Deep Reason extension](https://oobabooga.gumroad.com/l/deep_reason) |![Image1](https://github.com/oobabooga/screenshots/raw/main/INSTRUCT-3.5.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/CHAT-3.5.png) | |:---:|:---:| |![Image1](https://github.com/oobabooga/screenshots/raw/main/DEFAULT-3.5.png) | ![Image2](https://github.com/oobabooga/screenshots/raw/main/PARAMETERS-3.5.png) | ## Features - **Multiple backends**: [llama.cpp](https://github.com/ggml-org/llama.cpp), [Transformers](https://github.com/huggingface/transformers), [ExLlamaV3](https://github.com/turboderp-org/exllamav3), and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM). Switch between backends and models without restarting. - **File attachments**: Upload text files, PDF documents, and .docx documents to talk about their contents. - **Vision (multimodal)**: Attach images to messages for visual understanding ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Multimodal-Tutorial)). - **Tool-calling**: Models can call custom functions during chat — web search, page fetching, math, and more. Each tool is a single `.py` file, easy to create and extend ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Tool-Calling-Tutorial)). - **OpenAI-compatible API**: Chat and Completions endpoints with tool-calling support. Use as a local drop-in replacement for the OpenAI API ([examples](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples)). - **Training**: Fine-tune LoRAs on multi-turn chat or raw text datasets. Supports resuming interrupted runs ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)).
- **Image generation**: A dedicated tab for `diffusers` models like **Z-Image-Turbo**. Features 4-bit/8-bit quantization and a persistent gallery with metadata ([tutorial](https://github.com/oobabooga/text-generation-webui/wiki/Image-Generation-Tutorial)). - **Easy setup**: [Portable builds](https://github.com/oobabooga/text-generation-webui/releases) (zero setup, just unzip and run) for GGUF models on Windows/Linux/macOS, or a one-click installer for the full feature set. - 100% offline and private, with zero telemetry, external resources, or remote update requests. - `instruct` mode for instruction-following (like ChatGPT), and `chat-instruct`/`chat` modes for talking to custom characters. Prompts are automatically formatted with Jinja2 templates. - Edit messages, navigate between message versions, and branch conversations at any point. - Free-form text generation in the Notebook tab without being limited to chat turns. - Multiple sampling parameters and generation options for sophisticated text generation control. - Aesthetic UI with dark and light themes. - Syntax highlighting for code blocks and LaTeX rendering for mathematical expressions. - Extension support, with numerous built-in and user-contributed extensions available. See the [wiki](https://github.com/oobabooga/text-generation-webui/wiki/07-%E2%80%90-Extensions) and [extensions directory](https://github.com/oobabooga/text-generation-webui-extensions) for details. ## How to install #### ✅ Option 1: Portable builds (get started in 1 minute) No installation needed – just download, unzip and run. All dependencies included. Download from here: **https://github.com/oobabooga/text-generation-webui/releases** - Builds are provided for Linux, Windows, and macOS, with options for CUDA, Vulkan, ROCm, and CPU-only. - Compatible with GGUF (llama.cpp) models. 
#### Option 2: Manual portable install with venv Very fast setup that should work on any Python 3.9+: ```bash # Clone repository git clone https://github.com/oobabooga/text-generation-webui cd text-generation-webui # Create virtual environment python -m venv venv # Activate virtual environment # On Windows: venv\Scripts\activate # On macOS/Linux: source venv/bin/activate # Install dependencies (choose the appropriate file under requirements/portable for your hardware) pip install -r requirements/portable/requirements.txt --upgrade # Launch server (basic command) python server.py --portable --api --auto-launch # When done working, deactivate deactivate ``` #### Option 3: One-click installer For users who need additional backends (ExLlamaV3, Transformers), training, image generation, or extensions (TTS, voice input, translation, etc.). Requires ~10GB disk space and downloads PyTorch. 1. Clone the repository, or [download its source code](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and extract it. 2. Run the startup script for your OS: `start_windows.bat`, `start_linux.sh`, or `start_macos.sh`. 3. When prompted, select your GPU vendor. 4. After installation, open `http://127.0.0.1:7860` in your browser. To restart the web UI later, run the same `start_` script. You can pass command-line flags directly (e.g., `./start_linux.sh --help`), or add them to `user_data/CMD_FLAGS.txt` (e.g., `--api` to enable the API). To update, run the update script for your OS: `update_wizard_windows.bat`, `update_wizard_linux.sh`, or `update_wizard_macos.sh`. To reinstall with a fresh Python environment, delete the `installer_files` folder and run the `start_` script again.
One-click installer details ### One-click installer The script uses Miniforge to set up a Conda environment in the `installer_files` folder. If you ever need to install something manually in the `installer_files` environment, you can launch an interactive shell using the cmd script: `cmd_linux.sh`, `cmd_windows.bat`, or `cmd_macos.sh`. * There is no need to run any of those scripts (`start_`, `update_wizard_`, or `cmd_`) as admin/root. * To install requirements for extensions, it is recommended to use the update wizard script with the "Install/update extensions requirements" option. At the end, this script will install the main requirements for the project to make sure that they take precedence in case of version conflicts. * For automated installation, you can use the `GPU_CHOICE`, `LAUNCH_AFTER_INSTALL`, and `INSTALL_EXTENSIONS` environment variables. For instance: `GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh`.
Manual full installation with conda or docker ### Full installation with Conda #### 0. Install Conda https://github.com/conda-forge/miniforge On Linux or WSL, Miniforge can be automatically installed with these two commands: ``` curl -sL "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" > "Miniforge3.sh" bash Miniforge3.sh ``` For other platforms, download from: https://github.com/conda-forge/miniforge/releases/latest #### 1. Create a new conda environment ``` conda create -n textgen python=3.13 conda activate textgen ``` #### 2. Install PyTorch | System | GPU | Command | |--------|---------|---------| | Linux/WSL | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Linux/WSL | CPU only | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cpu` | | Linux | AMD | `pip3 install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-2.9.1%2Brocm7.2.0.lw.git7e1940d4-cp313-cp313-linux_x86_64.whl` | | macOS + MPS | Any | `pip3 install torch==2.9.1` | | Windows | NVIDIA | `pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128` | | Windows | CPU only | `pip3 install torch==2.9.1` | The up-to-date commands can be found here: https://pytorch.org/get-started/locally/. If you need `nvcc` to compile some library manually, you will additionally need to install this: ``` conda install -y -c "nvidia/label/cuda-12.8.1" cuda ``` #### 3.
Install the web UI ``` git clone https://github.com/oobabooga/text-generation-webui cd text-generation-webui pip install -r requirements/full/<requirements file according to table below> ``` Requirements file to use: | GPU | requirements file to use | |--------|---------| | NVIDIA | `requirements.txt` | | AMD | `requirements_amd.txt` | | CPU only | `requirements_cpu_only.txt` | | Apple Intel | `requirements_apple_intel.txt` | | Apple Silicon | `requirements_apple_silicon.txt` | ### Start the web UI ``` conda activate textgen cd text-generation-webui python server.py ``` Then browse to `http://127.0.0.1:7860` #### Manual install The `requirements*.txt` above contain various wheels precompiled through GitHub Actions. If you wish to compile things manually, or if you need to because no suitable wheels are available for your hardware, you can use `requirements_nowheels.txt` and then install your desired loaders manually. ### Alternative: Docker ``` For NVIDIA GPU: ln -s docker/{nvidia/Dockerfile,nvidia/docker-compose.yml,.dockerignore} . For AMD GPU: ln -s docker/{amd/Dockerfile,amd/docker-compose.yml,.dockerignore} . For Intel GPU: ln -s docker/{intel/Dockerfile,intel/docker-compose.yml,.dockerignore} . For CPU only: ln -s docker/{cpu/Dockerfile,cpu/docker-compose.yml,.dockerignore} . cp docker/.env.example .env # Create logs/cache dir: mkdir -p user_data/logs user_data/cache # Edit .env and set: # TORCH_CUDA_ARCH_LIST based on your GPU model # APP_RUNTIME_GID your host user's group id (run `id -g` in a terminal) # BUILD_EXTENSIONS optionally add comma-separated list of extensions to build # Edit user_data/CMD_FLAGS.txt and add in it the options you want to execute (like --listen --cpu) # docker compose up --build ``` * You need to have Docker Compose v2.17 or higher installed. See [this guide](https://github.com/oobabooga/text-generation-webui/wiki/09-%E2%80%90-Docker) for instructions. * For additional docker files, check out [this repository](https://github.com/Atinoda/text-generation-webui-docker).
### Updating the requirements From time to time, the `requirements*.txt` files change. To update, use these commands: ``` conda activate textgen cd text-generation-webui pip install -r <requirements file that you've used> --upgrade ```
List of command-line flags ```txt usage: server.py [-h] [--user-data-dir USER_DATA_DIR] [--multi-user] [--model MODEL] [--lora LORA [LORA ...]] [--model-dir MODEL_DIR] [--lora-dir LORA_DIR] [--model-menu] [--settings SETTINGS] [--extensions EXTENSIONS [EXTENSIONS ...]] [--verbose] [--idle-timeout IDLE_TIMEOUT] [--image-model IMAGE_MODEL] [--image-model-dir IMAGE_MODEL_DIR] [--image-dtype {bfloat16,float16}] [--image-attn-backend {flash_attention_2,sdpa}] [--image-cpu-offload] [--image-compile] [--image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo}] [--loader LOADER] [--ctx-size N] [--cache-type N] [--model-draft MODEL_DRAFT] [--draft-max DRAFT_MAX] [--gpu-layers-draft GPU_LAYERS_DRAFT] [--device-draft DEVICE_DRAFT] [--ctx-size-draft CTX_SIZE_DRAFT] [--spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache}] [--spec-ngram-size-n SPEC_NGRAM_SIZE_N] [--spec-ngram-size-m SPEC_NGRAM_SIZE_M] [--spec-ngram-min-hits SPEC_NGRAM_MIN_HITS] [--gpu-layers N] [--cpu-moe] [--mmproj MMPROJ] [--streaming-llm] [--tensor-split TENSOR_SPLIT] [--row-split] [--no-mmap] [--mlock] [--no-kv-offload] [--batch-size BATCH_SIZE] [--ubatch-size UBATCH_SIZE] [--threads THREADS] [--threads-batch THREADS_BATCH] [--numa] [--parallel PARALLEL] [--fit-target FIT_TARGET] [--extra-flags EXTRA_FLAGS] [--cpu] [--cpu-memory CPU_MEMORY] [--disk] [--disk-cache-dir DISK_CACHE_DIR] [--load-in-8bit] [--bf16] [--no-cache] [--trust-remote-code] [--force-safetensors] [--no_use_fast] [--attn-implementation IMPLEMENTATION] [--load-in-4bit] [--use_double_quant] [--compute_dtype COMPUTE_DTYPE] [--quant_type QUANT_TYPE] [--gpu-split GPU_SPLIT] [--enable-tp] [--tp-backend TP_BACKEND] [--cfg-cache] [--listen] [--listen-port LISTEN_PORT] [--listen-host LISTEN_HOST] [--share] [--auto-launch] [--gradio-auth GRADIO_AUTH] [--gradio-auth-path GRADIO_AUTH_PATH] [--ssl-keyfile SSL_KEYFILE] [--ssl-certfile SSL_CERTFILE] [--subpath SUBPATH] [--old-colors] [--portable] [--api] 
[--public-api] [--public-api-id PUBLIC_API_ID] [--api-port API_PORT] [--api-key API_KEY] [--admin-key ADMIN_KEY] [--api-enable-ipv6] [--api-disable-ipv4] [--nowebui] [--temperature N] [--dynatemp-low N] [--dynatemp-high N] [--dynatemp-exponent N] [--smoothing-factor N] [--smoothing-curve N] [--min-p N] [--top-p N] [--top-k N] [--typical-p N] [--xtc-threshold N] [--xtc-probability N] [--epsilon-cutoff N] [--eta-cutoff N] [--tfs N] [--top-a N] [--top-n-sigma N] [--adaptive-target N] [--adaptive-decay N] [--dry-multiplier N] [--dry-allowed-length N] [--dry-base N] [--repetition-penalty N] [--frequency-penalty N] [--presence-penalty N] [--encoder-repetition-penalty N] [--no-repeat-ngram-size N] [--repetition-penalty-range N] [--penalty-alpha N] [--guidance-scale N] [--mirostat-mode N] [--mirostat-tau N] [--mirostat-eta N] [--do-sample | --no-do-sample] [--dynamic-temperature | --no-dynamic-temperature] [--temperature-last | --no-temperature-last] [--sampler-priority N] [--dry-sequence-breakers N] [--enable-thinking | --no-enable-thinking] [--reasoning-effort N] [--chat-template-file CHAT_TEMPLATE_FILE] Text Generation Web UI options: -h, --help show this help message and exit Basic settings: --user-data-dir USER_DATA_DIR Path to the user data directory. Default: auto-detected. --multi-user Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams. --model MODEL Name of the model to load by default. --lora LORA [LORA ...] The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces. --model-dir MODEL_DIR Path to directory with all the models. --lora-dir LORA_DIR Path to directory with all the loras. --model-menu Show a model menu in the terminal when the web UI is first launched. --settings SETTINGS Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. 
If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag. --extensions EXTENSIONS [EXTENSIONS ...] The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. --verbose Print the prompts to the terminal. --idle-timeout IDLE_TIMEOUT Unload model after this many minutes of inactivity. It will be automatically reloaded when you try to use it again. Image model: --image-model IMAGE_MODEL Name of the image model to select on startup (overrides saved setting). --image-model-dir IMAGE_MODEL_DIR Path to directory with all the image models. --image-dtype {bfloat16,float16} Data type for image model. --image-attn-backend {flash_attention_2,sdpa} Attention backend for image model. --image-cpu-offload Enable CPU offloading for image model. --image-compile Compile the image model for faster inference. --image-quant {none,bnb-8bit,bnb-4bit,torchao-int8wo,torchao-fp4,torchao-float8wo} Quantization method for image model. Model loader: --loader LOADER Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-LLM. Context and cache: --ctx-size, --n_ctx, --max_seq_len N Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. --cache-type, --cache_type N KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8). Speculative decoding: --model-draft MODEL_DRAFT Path to the draft model for speculative decoding. --draft-max DRAFT_MAX Number of tokens to draft for speculative decoding. --gpu-layers-draft GPU_LAYERS_DRAFT Number of layers to offload to the GPU for the draft model. --device-draft DEVICE_DRAFT Comma-separated list of devices to use for offloading the draft model.
Example: CUDA0,CUDA1 --ctx-size-draft CTX_SIZE_DRAFT Size of the prompt context for the draft model. If 0, uses the same as the main model. --spec-type {none,ngram-mod,ngram-simple,ngram-map-k,ngram-map-k4v,ngram-cache} Draftless speculative decoding type. Recommended: ngram-mod. --spec-ngram-size-n SPEC_NGRAM_SIZE_N N-gram lookup size for ngram speculative decoding. --spec-ngram-size-m SPEC_NGRAM_SIZE_M Draft n-gram size for ngram speculative decoding. --spec-ngram-min-hits SPEC_NGRAM_MIN_HITS Minimum n-gram hits for ngram-map speculative decoding. llama.cpp: --gpu-layers, --n-gpu-layers N Number of layers to offload to the GPU. -1 = auto. --cpu-moe Move the experts to the CPU (for MoE models). --mmproj MMPROJ Path to the mmproj file for vision models. --streaming-llm Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed. --tensor-split TENSOR_SPLIT Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40. --row-split Split the model by rows across GPUs. This may improve multi-gpu performance. --no-mmap Prevent mmap from being used. --mlock Force the system to keep the model in RAM. --no-kv-offload Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance. --batch-size BATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size. --ubatch-size UBATCH_SIZE Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level). --threads THREADS Number of threads to use. --threads-batch THREADS_BATCH Number of threads to use for batches/prompt processing. --numa Activate NUMA task allocation for llama.cpp. --parallel PARALLEL Number of parallel request slots. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768. 
--fit-target FIT_TARGET Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices. Default: 1024. --extra-flags EXTRA_FLAGS Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU" Transformers/Accelerate: --cpu Use the CPU to generate text. Warning: Training on CPU is extremely slow. --cpu-memory CPU_MEMORY Maximum CPU memory in GiB. Use this for CPU offloading. --disk If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. --disk-cache-dir DISK_CACHE_DIR Directory to save the disk cache to. --load-in-8bit Load the model with 8-bit precision (using bitsandbytes). --bf16 Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. --no-cache Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost. --trust-remote-code Set trust_remote_code=True while loading the model. Necessary for some models. --force-safetensors Set use_safetensors=True while loading the model. This prevents arbitrary code execution. --no_use_fast Set use_fast=False while loading the tokenizer (it's True by default). Use this if you have any problems related to use_fast. --attn-implementation IMPLEMENTATION Attention implementation. Valid options: sdpa, eager, flash_attention_2. bitsandbytes 4-bit: --load-in-4bit Load the model with 4-bit precision (using bitsandbytes). --use_double_quant use_double_quant for 4-bit. --compute_dtype COMPUTE_DTYPE compute dtype for 4-bit. Valid options: bfloat16, float16, float32. --quant_type QUANT_TYPE quant_type for 4-bit. Valid options: nf4, fp4. ExLlamaV3: --gpu-split GPU_SPLIT Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7. --enable-tp, --enable_tp Enable Tensor Parallelism (TP) to split the model across GPUs. --tp-backend TP_BACKEND The backend for tensor parallelism. 
Valid options: native, nccl. Default: native. --cfg-cache Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader. Gradio: --listen Make the web UI reachable from your local network. --listen-port LISTEN_PORT The listening port that the server will use. --listen-host LISTEN_HOST The hostname that the server will use. --share Create a public URL. This is useful for running the web UI on Google Colab or similar. --auto-launch Open the web UI in the default browser upon launch. --gradio-auth GRADIO_AUTH Set Gradio authentication password in the format "username:password". Multiple credentials can also be supplied with "u1:p1,u2:p2,u3:p3". --gradio-auth-path GRADIO_AUTH_PATH Set the Gradio authentication file path. The file should contain one or more user:password pairs in the same format as above. --ssl-keyfile SSL_KEYFILE The path to the SSL certificate key file. --ssl-certfile SSL_CERTFILE The path to the SSL certificate cert file. --subpath SUBPATH Customize the subpath for gradio, use with reverse proxy --old-colors Use the legacy Gradio colors, before the December/2024 update. --portable Hide features not available in portable mode like training. API: --api Enable the API extension. --public-api Create a public URL for the API using Cloudflare. --public-api-id PUBLIC_API_ID Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. --api-port API_PORT The listening port for the API. --api-key API_KEY API authentication key. --admin-key ADMIN_KEY API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key. --api-enable-ipv6 Enable IPv6 for the API --api-disable-ipv4 Disable IPv4 for the API --nowebui Do not launch the Gradio UI. Useful for launching the API in standalone mode. 
API generation defaults: --temperature N Temperature --dynatemp-low N Dynamic temperature low --dynatemp-high N Dynamic temperature high --dynatemp-exponent N Dynamic temperature exponent --smoothing-factor N Smoothing factor --smoothing-curve N Smoothing curve --min-p N Min P --top-p N Top P --top-k N Top K --typical-p N Typical P --xtc-threshold N XTC threshold --xtc-probability N XTC probability --epsilon-cutoff N Epsilon cutoff --eta-cutoff N Eta cutoff --tfs N TFS --top-a N Top A --top-n-sigma N Top N Sigma --adaptive-target N Adaptive target --adaptive-decay N Adaptive decay --dry-multiplier N DRY multiplier --dry-allowed-length N DRY allowed length --dry-base N DRY base --repetition-penalty N Repetition penalty --frequency-penalty N Frequency penalty --presence-penalty N Presence penalty --encoder-repetition-penalty N Encoder repetition penalty --no-repeat-ngram-size N No repeat ngram size --repetition-penalty-range N Repetition penalty range --penalty-alpha N Penalty alpha --guidance-scale N Guidance scale --mirostat-mode N Mirostat mode --mirostat-tau N Mirostat tau --mirostat-eta N Mirostat eta --do-sample, --no-do-sample Do sample --dynamic-temperature, --no-dynamic-temperature Dynamic temperature --temperature-last, --no-temperature-last Temperature last --sampler-priority N Sampler priority --dry-sequence-breakers N DRY sequence breakers --enable-thinking, --no-enable-thinking Enable thinking --reasoning-effort N Reasoning effort --chat-template-file CHAT_TEMPLATE_FILE Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. Overrides the model's built-in template. ```
## Downloading models 1. Download a GGUF model file from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=gguf). 2. Place it in the `user_data/models` folder. That's it. The UI will detect it automatically. To check what will fit your GPU, you can use the [VRAM Calculator](https://huggingface.co/spaces/oobabooga/accurate-gguf-vram-calculator).
Other model types (Transformers, EXL3) Models that consist of multiple files (like 16-bit Transformers models and EXL3 models) should be placed in a subfolder inside `user_data/models`: ``` text-generation-webui └── user_data └── models └── Qwen_Qwen3-8B ├── config.json ├── generation_config.json ├── model-00001-of-00004.safetensors ├── ... ├── tokenizer_config.json └── tokenizer.json ``` These formats require the one-click installer (not the portable build).
## Documentation https://github.com/oobabooga/text-generation-webui/wiki ## Community https://www.reddit.com/r/Oobabooga/ ## Acknowledgments - In August 2023, [Andreessen Horowitz](https://a16z.com/) (a16z) provided a generous grant to encourage and support my independent work on this project. I am **extremely** grateful for their trust and recognition. - This project was inspired by [AUTOMATIC1111/stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and wouldn't exist without it. ================================================ FILE: cmd_linux.sh ================================================ #!/usr/bin/env bash cd "$(dirname "${BASH_SOURCE[0]}")" if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. && exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null # config CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate env bash --init-file <(echo "source \"$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh\" && conda activate \"$INSTALL_ENV_DIR\"") ================================================ FILE: cmd_macos.sh ================================================ #!/bin/bash cd "$(dirname "${BASH_SOURCE[0]}")" if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. 
&& exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null # config CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate env source $CONDA_ROOT_PREFIX/etc/profile.d/conda.sh conda activate $INSTALL_ENV_DIR exec bash --norc ================================================ FILE: cmd_windows.bat ================================================ @echo off cd /D "%~dp0" set PATH=%PATH%;%SystemRoot%\system32 echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniforge which can not be silently installed under a path with spaces. && goto end @rem fix failed install when installing to a separate drive set TMP=%cd%\installer_files set TEMP=%cd%\installer_files @rem deactivate existing conda envs as needed to avoid conflicts (call conda deactivate && call conda deactivate && call conda deactivate) 2>nul @rem config set CONDA_ROOT_PREFIX=%cd%\installer_files\conda set INSTALL_ENV_DIR=%cd%\installer_files\env @rem environment isolation set PYTHONNOUSERSITE=1 set PYTHONPATH= set PYTHONHOME= set PYTHONUTF8=1 set "CUDA_PATH=%INSTALL_ENV_DIR%" set "CUDA_HOME=%CUDA_PATH%" @rem activate installer env call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniforge hook not found. 
&& goto end ) @rem enter commands cmd /k "%*" :end pause ================================================ FILE: docker/.dockerignore ================================================ .env Dockerfile /user_data ================================================ FILE: docker/TensorRT-LLM/Dockerfile ================================================ FROM nvidia/cuda:13.0.1-cudnn-runtime-ubuntu24.04 # Install Python 3.12, Git, and OpenMPI RUN apt update && apt install -y python3.12 python3-pip git build-essential openmpi-bin libopenmpi-dev # Set the working directory WORKDIR /app # This is needed to avoid an error about "Failed to build mpi4py" in the next command ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH # Install text-generation-webui RUN git clone https://github.com/oobabooga/text-generation-webui WORKDIR /app/text-generation-webui RUN pip install --break-system-packages -r requirements/full/requirements.txt # Install TensorRT-LLM RUN pip3 install --break-system-packages tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com # Expose the necessary port for the Python server EXPOSE 7860 5000 # Run the Python server.py script with the specified command CMD ["python3", "server.py", "--api", "--listen"] ================================================ FILE: docker/amd/Dockerfile ================================================ # BUILDER FROM ubuntu:22.04 WORKDIR /builder ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG APP_UID="${APP_UID:-6972}" ARG APP_GID="${APP_GID:-6972}" RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \ apt update && \ apt install --no-install-recommends -y git vim build-essential python3-dev pip bash curl && \ rm -rf /var/lib/apt/lists/* WORKDIR /home/app/ RUN git clone https://github.com/oobabooga/text-generation-webui.git WORKDIR /home/app/text-generation-webui RUN GPU_CHOICE=B LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose EXPOSE ${CONTAINER_PORT:-7860} 
${CONTAINER_API_PORT:-5000} WORKDIR /home/app/text-generation-webui # set umask to ensure group read / write at runtime CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen ================================================ FILE: docker/amd/docker-compose.yml ================================================ version: "3.3" services: text-generation-webui: build: context: . args: BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} APP_GID: ${APP_GID:-6972} APP_UID: ${APP_UID:-6972} env_file: .env user: "${APP_RUNTIME_UID:-6972}:${APP_RUNTIME_GID:-6972}" ports: - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}" - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}" stdin_open: true group_add: - video tty: true ipc: host devices: - /dev/kfd - /dev/dri cap_add: - SYS_PTRACE security_opt: - seccomp=unconfined volumes: - ./user_data:/home/app/text-generation-webui/user_data ================================================ FILE: docker/cpu/Dockerfile ================================================ # BUILDER FROM ubuntu:22.04 WORKDIR /builder ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG APP_UID="${APP_UID:-6972}" ARG APP_GID="${APP_GID:-6972}" RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \ apt update && \ apt install --no-install-recommends -y git vim build-essential python3-dev pip bash curl && \ rm -rf /var/lib/apt/lists/* WORKDIR /home/app/ RUN git clone https://github.com/oobabooga/text-generation-webui.git WORKDIR /home/app/text-generation-webui RUN GPU_CHOICE=N LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} # set umask to ensure group read / write at runtime WORKDIR /home/app/text-generation-webui CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen ================================================ FILE: docker/cpu/docker-compose.yml ================================================ version: "3.3" 
services: text-generation-webui: build: context: . args: BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} APP_GID: ${APP_GID:-6972} APP_UID: ${APP_UID:-6972} env_file: .env user: "${APP_RUNTIME_UID:-6972}:${APP_RUNTIME_GID:-6972}" ports: - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}" - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}" stdin_open: true tty: true volumes: - ./user_data:/home/app/text-generation-webui/user_data ================================================ FILE: docker/intel/Dockerfile ================================================ # BUILDER FROM ubuntu:22.04 WORKDIR /builder ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG APP_UID="${APP_UID:-6972}" ARG APP_GID="${APP_GID:-6972}" RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \ apt update && \ apt install --no-install-recommends -y git vim build-essential python3-dev pip bash curl && \ rm -rf /var/lib/apt/lists/* WORKDIR /home/app/ RUN git clone https://github.com/oobabooga/text-generation-webui.git WORKDIR /home/app/text-generation-webui RUN GPU_CHOICE=D LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} # set umask to ensure group read / write at runtime WORKDIR /home/app/text-generation-webui CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen ================================================ FILE: docker/intel/docker-compose.yml ================================================ version: "3.3" services: text-generation-webui: build: context: . 
args: BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} APP_GID: ${APP_GID:-6972} APP_UID: ${APP_UID:-6972} env_file: .env user: "${APP_RUNTIME_UID:-6972}:${APP_RUNTIME_GID:-6972}" ports: - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}" - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}" stdin_open: true group_add: - video tty: true ipc: host devices: - /dev/kfd - /dev/dri cap_add: - SYS_PTRACE security_opt: - seccomp=unconfined volumes: - ./user_data:/home/app/text-generation-webui/user_data ================================================ FILE: docker/nvidia/Dockerfile ================================================ # BUILDER FROM ubuntu:22.04 WORKDIR /builder ARG TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-3.5;5.0;6.0;6.1;7.0;7.5;8.0;8.6+PTX}" ARG BUILD_EXTENSIONS="${BUILD_EXTENSIONS:-}" ARG APP_UID="${APP_UID:-6972}" ARG APP_GID="${APP_GID:-6972}" RUN --mount=type=cache,target=/var/cache/apt,sharing=locked,rw \ apt update && \ apt install --no-install-recommends -y git vim build-essential python3-dev pip bash curl && \ rm -rf /var/lib/apt/lists/* WORKDIR /home/app/ RUN git clone https://github.com/oobabooga/text-generation-webui.git WORKDIR /home/app/text-generation-webui RUN GPU_CHOICE=A LAUNCH_AFTER_INSTALL=FALSE INSTALL_EXTENSIONS=TRUE ./start_linux.sh --verbose EXPOSE ${CONTAINER_PORT:-7860} ${CONTAINER_API_PORT:-5000} WORKDIR /home/app/text-generation-webui # set umask to ensure group read / write at runtime CMD umask 0002 && export HOME=/home/app/text-generation-webui && ./start_linux.sh --listen ================================================ FILE: docker/nvidia/docker-compose.yml ================================================ version: "3.3" services: text-generation-webui: build: context: . 
args: # specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-8.6;8.9+PTX} BUILD_EXTENSIONS: ${BUILD_EXTENSIONS:-} APP_GID: ${APP_GID:-6972} APP_UID: ${APP_UID:-6972} env_file: .env user: "${APP_RUNTIME_UID:-6972}:${APP_RUNTIME_GID:-6972}" ports: - "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}" - "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}" stdin_open: true tty: true volumes: - ./user_data:/home/app/text-generation-webui/user_data deploy: resources: reservations: devices: - driver: nvidia count: all capabilities: [gpu] ================================================ FILE: docs/01 - Chat Tab.md ================================================ Used to have multi-turn conversations with the model. ## Input area The main action buttons are: * **Send**: sends your message and makes the model start a reply. * **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model). The hover menu (☰) that appears over the chat area contains: * **Regenerate**: similar to Send, but your last message is used as input instead of the text in the input field. Note that if the temperature/top_p/top_k parameters are low in the "Parameters" tab of the UI, the new reply may end up identical to the previous one. * **Continue**: makes the model attempt to continue the existing reply. In some cases, the model may simply end the existing turn immediately without generating anything new, but in other cases, it may generate a longer reply. * **Remove last reply**: removes the last input/output pair from the history and sends your last message back into the input field. * **Impersonate**: makes the model generate a new message on your behalf in the input field, taking into consideration the existing chat history. * **Send dummy message**: adds a new message to the chat history without causing the model to generate a reply. 
* **Send dummy reply**: adds a new reply to the chat history as if the model had generated this reply. Useful in conjunction with "Send dummy message". * **Send to Notebook**: sends the entire chat prompt up to now to the Notebook tab. * **Show controls**: checkbox that toggles the visibility of the sidebar controls (Start reply with, Mode, Chat style, etc.). Shortcut: Ctrl+S. ## Past chats Allows you to switch between the current and previous conversations with the current character, or between the current and previous instruct conversations (if in "instruct" mode). The available buttons are: * **Branch**: creates a branch of the current conversation at a specific message. * **Rename**: allows you to give a unique name to the selected conversation. * **🗑️**: deletes the selected conversation. * **New chat**: starts a new conversation. If you are talking to a character that has a "Greeting" message defined, this message will be automatically added to the new history. A search field is also available to filter conversations by name. ## Sidebar controls The sidebar (toggled via "Show controls") contains: * **Start reply with**: whatever you type there will appear at the start of every reply by the bot. This is useful to guide the response in the desired direction. * **Reasoning effort**: controls the thinking depth for models that support it. Options: low, medium, high. * **Enable thinking**: enables extended thinking mode for models that support it. * **Activate web search**: when enabled, the model can search the web for information before replying. You can also set the number of pages to download. * **Mode**: see below. * **Chat style**: see below. * **Command for chat-instruct mode**: the command that is used in chat-instruct mode to query the model to generate a reply on behalf of the character. Can be used creatively to generate specific kinds of responses. 
Inside this string, `<|character|>` is a placeholder that gets replaced with the bot name, and `<|prompt|>` is a placeholder that gets replaced with the full chat prompt. ## Mode This is the most important input field. It defines how the chat prompt is formatted. There are 3 options: chat, chat-instruct, and instruct. It is worth going into more detail about this because the distinction between these modes is not obvious to many people. ### Instruction-following models There are two kinds of models: base models, like Llama and GPT-J, and fine-tuned models, like Alpaca and Vicuna. Fine-tuned models are trained starting from base models, most often with the goal of getting the model to understand and respond to instructions just like ChatGPT does. Let's call such models *instruction-following models*. Each instruction-following model was trained on a specific prompt format, and you have to use that exact prompt format if you want the model to follow your instructions as accurately as it can. As an example, this is the Alpaca format: ``` Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Hi there! ### Response: Hello! It's nice to meet you. What can I help with? ### Instruction: How are you? ### Response: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have. ``` This format is characterized by a context string at the top, and alternating turns where each user input starts with `### Instruction:` and each bot turn starts with `### Response:`. There are also weirder formats, like the one used by the Llama-2-chat models released by Meta AI: ``` [INST] <<SYS>> Answer the questions. <</SYS>> Hi there! [/INST] Hello! It's nice to meet you. What can I help with? </s><s>[INST] How are you? [/INST] I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me?
I'm here to help answer any questions you may have. ``` In this format, there are special tokens at the end of each bot reply (`</s>`, the end of sequence token, and `<s>`, the beginning of sequence token); no new lines separating the turns; and the context string is written between `<<SYS>>` and `<</SYS>>`. Despite the intimidating look of this format, the logic is the same: there are user turns and bot turns, and each one appears in a specific place in the template. It is important to emphasize that instruction-following models **have to be used with the exact prompt format that they were trained on**. Using those models with any other prompt format should be considered undefined behavior. The model will still generate replies, but they will follow your inputs less accurately. Now that an instruction-following model is defined, we can move on to describing the 3 chat modes. ### Chat Used for talking to the character defined under the "Character" tab using a simple chat prompt in this format: ``` Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology. You: Hi there! Chiharu Yamada: Hello! It's nice to meet you. What can I help with? You: How are you? Chiharu Yamada: I'm doing well, thank you for asking! Is there something specific you would like to talk about or ask me? I'm here to help answer any questions you may have. ``` There are 3 adjustable parameters in the "Character" tab being used in this prompt: * The **Context** string appears at the top of the prompt. Most often it describes the bot's personality and adds a few example messages to guide the model towards the desired reply length and format. This string never gets truncated: as the prompt size increases, old messages get removed one at a time until the prompt becomes smaller than the truncation length set under "Parameters" > "Generation" > "Truncate the prompt up to this length".
* The **Your name** string appears at the beginning of each user reply. By default, this string is "You". * The **Character's name** string appears at the beginning of each bot reply. Additionally, the **Greeting** string appears as the bot's opening message whenever the history is cleared. The "Chat" option should typically be used only for base models or non-instruct fine tunes, and should not be used for instruction-following models. ### Instruct Used for talking to an instruction-following model using the prompt format defined under "Parameters" > "Instruction template". Think of this option as an offline ChatGPT. The prompt format is defined by the **Instruction template** parameter in "Parameters" > "Instruction template", which represents a Jinja2 template. Note that when you load a model in the "Model" tab, the web UI will try to automatically detect its instruction template (if any), and will update the values under "Parameters" > "Instruction template" accordingly. This is done using a set of regular expressions defined in `user_data/models/config.yaml`. This detection is not guaranteed to be accurate. You should check the model card on Hugging Face to see if you are using the correct prompt format. ### Chat-instruct As said above, instruction-following models are meant to be used with their specific prompt templates. The chat-instruct mode allows you to use those templates to generate a chat reply, thus mixing Chat and Instruct modes (hence the name). It works by creating a single instruction-following turn where a command is given followed by the regular chat prompt. Here is an example in Alpaca format: ``` Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Continue the chat dialogue below. Write a single reply for the character "Chiharu Yamada". Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology. 
You: Hi there! Chiharu Yamada: Hello! It's nice to meet you. What can I help with? You: How are you? ### Response: Chiharu Yamada: ``` Here, the command is > Continue the chat dialogue below. Write a single reply for the character "Chiharu Yamada". Below this command, the regular chat prompt is added, including its Context string and the chat history, and then the user turn ends. The bot turn starts with the "Character's name" string followed by `:`, thus prompting the instruction-following model to write a single reply for the character. Note that you can get creative: instead of writing something trivial like "Write a single reply for the character", you could add more complex instructions like > This is an adventure game, and your task is to write a reply in name of "<|character|>" where 3 options are given for the user to then choose from. And it works: ![chat-instruct](https://github.com/oobabooga/text-generation-webui/assets/112222186/e38e3469-8263-4a10-b1a1-3c955026b8e7) ## Chat style This defines the visual style of the chat UI. Each option is a CSS file defined under `text-generation-webui/css/chat_style-name.css`, where "name" is how this style is called in the dropdown menu. You can add new styles by simply copying `chat_style-cai-chat.css` to `chat_style-myNewStyle.css` and editing the contents of this new file. If you end up with a style that you like, you are highly encouraged to submit it to the repository. The styles are only applied to chat and chat-instruct modes. Instruct mode has its separate style defined in `text-generation-webui/css/html_instruct_style.css`. ## Character gallery This menu is a built-in extension defined under `text-generation-webui/extensions/gallery`. It displays a gallery with your characters, and if you click on a character, it will be automatically selected in the Character tab. 
================================================ FILE: docs/02 - Default and Notebook Tabs.md ================================================ Used to generate raw completions starting from your prompt. ## Default tab This tab contains two main text boxes: Input, where you enter your prompt, and Output, where the model output will appear. ### Input The number on the lower right of the Input box counts the number of tokens in the input. It gets updated whenever you update the input text as long as a model is loaded (otherwise there is no tokenizer to count the tokens). Below the Input box, the following buttons can be found: * **Continue**: starts a new generation taking as input the text in the "Output" box. * **Generate**: starts a new generation. * **Stop**: stops an ongoing generation as soon as the next token is generated (which can take a while for a slow model). In the **Prompt** menu, you can select from saved prompts stored in `user_data/logs/notebook`. The **New** button creates a new prompt, the **Rename** button renames the selected prompt, and the 🗑️ button deletes it. The 🔄 button refreshes the list. ### Output Five tabs can be found: * **Raw**: where the raw text generated by the model appears. * **Markdown**: it contains a "Render" button. You can click on it at any time to render the current output as markdown. This is particularly useful for models that generate LaTeX equations like GALACTICA. * **HTML**: displays the output in an HTML style that is meant to be easier to read. Its style is defined under `text-generation-webui/css/html_readable_style.css`. * **Logits**: when you click on "Get next token probabilities", this tab displays the 50 most likely next tokens and their probabilities based on your current input. If "Use samplers" is checked, the probabilities will be the ones after the sampling parameters in the "Parameters" > "Generation" tab are applied. Otherwise, they will be the raw probabilities generated by the model. 
* **Tokens**: allows you to tokenize your prompt and see the ID numbers for the individual tokens. ## Notebook tab Precisely the same thing as the Default tab, with the difference that the output appears in the same text box as the input. It contains the following additional button: * **Regenerate**: uses your previous input for generation while discarding the last output. ================================================ FILE: docs/03 - Parameters Tab.md ================================================ ## Generation Contains parameters that control the text generation. ### Quick rundown LLMs work by generating one token at a time. Given your prompt, the model calculates the probabilities for every possible next token. The actual token generation is done after that. * In *greedy decoding*, the most likely token is always picked. * Most commonly, *sampling* techniques are used to choose from the next-token distribution in a more non-trivial way with the goal of improving the quality of the generated text. ### Preset menu Can be used to save and load combinations of parameters for reuse. * **🎲 button**: creates a random yet interpretable preset. Only 1 parameter of each category is included for the categories: removing tail tokens, avoiding repetition, and flattening the distribution. That is, top_p and top_k are not mixed, and neither are repetition_penalty and frequency_penalty. You can use this button to break out of a loop of bad generations after multiple "Regenerate" attempts. #### Built-in presets These were obtained after a blind contest called "Preset Arena" where hundreds of people voted. The full results can be found [here](https://github.com/oobabooga/oobabooga.github.io/blob/main/arena/results.md). A key takeaway is that the best presets are: * **For Instruct**: Divine Intellect, Big O, simple-1. * **For Chat**: Midnight Enigma, Yara, Shortwave. 
The other presets are: * Mirostat: a special decoding technique first implemented in llama.cpp and then adapted into this repository for all loaders. Many people have obtained positive results with it for chat. * LLaMA-Precise: a legacy preset that was the default for the web UI before the Preset Arena. * Debug-deterministic: disables sampling. It is useful for debugging, or if you intentionally want to use greedy decoding. ### Parameters description For more information about the parameters, the [transformers documentation](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig) is a good reference. * **max_new_tokens**: Maximum number of tokens to generate. Don't set it higher than necessary: it is used in the truncation calculation through the formula `(prompt_length) = min(truncation_length - max_new_tokens, prompt_length)`, so your prompt will get truncated if you set it too high. * **temperature**: Primary factor to control the randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness. * **top_p**: If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results. * **min_p**: Tokens with probability smaller than `(min_p) * (probability of the most likely token)` are discarded. This is the same as top_a but without squaring the probability. * **top_k**: Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results. * **repetition_penalty**: Penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition. * **presence_penalty**: Similar to repetition_penalty, but with an additive offset on the raw token scores instead of a multiplicative factor. It may generate better results. 0 means no penalty, higher value = less repetition, lower value = more repetition. 
Previously called "additive_repetition_penalty". * **frequency_penalty**: Repetition penalty that scales based on how many times the token has appeared in the context. Be careful with this; there's no limit to how much a token can be penalized. * **repetition_penalty_range**: The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used. * **dry_multiplier**: Set to greater than 0 to enable DRY (Don't Repeat Yourself) sampling. It penalizes tokens that would extend a sequence that already appeared in the context. Recommended value: 0.8. * **dry_allowed_length**: The longest sequence that can be repeated without being penalized by DRY. Shorter values make DRY more aggressive. * **dry_base**: Controls how fast the DRY penalty grows with increasing sequence length. * **typical_p**: If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text. * **tfs**: Tries to detect a tail of low-probability tokens in the distribution and removes those tokens. See [this blog post](https://www.trentonbricken.com/Tail-Free-Sampling/) for details. The closer to 0, the more discarded tokens. * **top_a**: Tokens with probability smaller than `(top_a) * (probability of the most likely token)^2` are discarded. * **top_n_sigma**: Keeps only tokens within N standard deviations of the mean log-probability. Acts as an adaptive cutoff that adjusts to the shape of the distribution. 0 disables it. * **xtc_threshold**: eXclusion from Top Choices (XTC) sampling. If 2 or more tokens have probability above this threshold, the top token may be removed. This encourages the model to use less common word choices and can increase creativity. * **xtc_probability**: The probability that XTC removal will actually happen when the threshold condition is met. Set to 1 for it to always apply, or lower for occasional application. * **epsilon_cutoff**: In units of 1e-4; a reasonable value is 3. 
This sets a probability floor below which tokens are excluded from being sampled. * **eta_cutoff**: In units of 1e-4; a reasonable value is 3. The main parameter of the special Eta Sampling technique. See [this paper](https://arxiv.org/pdf/2210.15191.pdf) for a description. * **guidance_scale**: The main parameter for Classifier-Free Guidance (CFG). [The paper](https://arxiv.org/pdf/2306.17806.pdf) suggests that 1.5 is a good value. It can be used in conjunction with a negative prompt or not. * **Negative prompt**: Only used when `guidance_scale != 1`. It is most useful for instruct models and custom system messages. You place your full prompt in this field with the system message replaced with the default one for the model (like "You are Llama, a helpful assistant...") to make the model pay more attention to your custom system message. * **penalty_alpha**: Contrastive Search is enabled by setting this to greater than zero and unchecking "do_sample". It should be used with a low value of top_k, for instance, top_k = 4. * **mirostat_mode**: Activates Mirostat sampling, an adaptive decoding method that dynamically controls output perplexity for higher-quality text generation. 0 is disabled. 1 is the classic Mirostat algorithm described in [the paper](https://arxiv.org/abs/2007.14966), but can be less stable, or “wobbly,” and produce less coherent text. 2 is the improved version that is more stable and has lower perplexity, recommended for most use cases. *Note: Use either mirostat or dynamic_temperature, not both at the same time.* * **mirostat_tau**: Target perplexity for Mirostat sampling. Controls how “surprising” the text is. Higher values = more diverse, lower = more predictable. Preset Arena suggests 8 as a good value. * **mirostat_eta**: Learning rate for Mirostat’s perplexity adjustment. Higher values = adapts faster but less stable, lower values = slower but more stable. Preset Arena suggests 0.1 as a good value. 
* **adaptive_target**: Target probability for adaptive-p sampling. This method adjusts the sampling threshold dynamically based on an exponential moving average of recent token probabilities. 0 disables it. * **adaptive_decay**: EMA decay rate for adaptive-p sampling. Controls how quickly the running average adjusts. Default: 0.9. * **dynamic_temperature**: Activates Dynamic Temperature. This modifies temperature to range between "dynatemp_low" (minimum) and "dynatemp_high" (maximum), with an entropy-based scaling. The steepness of the curve is controlled by "dynatemp_exponent". *Note: Use either dynamic_temperature or mirostat, not both at the same time.* * **smoothing_factor**: Activates Quadratic Sampling. When `0 < smoothing_factor < 1`, the logits distribution becomes flatter. When `smoothing_factor > 1`, it becomes more peaked. * **smoothing_curve**: Adjusts the dropoff curve of Quadratic Sampling. Higher values make the curve steeper. Only takes effect when smoothing_factor is set. * **temperature_last**: Makes temperature the last sampler instead of the first. With this, you can remove low probability tokens with a sampler like min_p and then use a high temperature to make the model creative without losing coherence. Note: this parameter takes precedence over "Sampler priority". That means that `temperature`/`dynamic_temperature`/`quadratic_sampling` will be removed from wherever they are and moved to the end of the stack. * **do_sample**: When unchecked, sampling is entirely disabled, and greedy decoding is used instead (the most likely token is always picked). * **Seed**: Set the PyTorch seed to this number. Note that some loaders do not use PyTorch (notably llama.cpp). For these loaders, the seed has no effect. * **encoder_repetition_penalty**: Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.
* **no_repeat_ngram_size**: If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases. To the right (or below if you are on mobile), the following parameters are present: * **Truncate the prompt up to this length**: Used to prevent the prompt from getting bigger than the model's context length. In the case of the transformers loader, which allocates memory dynamically, this parameter can also be used to set a VRAM ceiling and prevent out-of-memory errors. This parameter is automatically updated with the model's context length (from "ctx_size" for loaders that use this parameter, and from the model metadata directly for loaders that do not) when you load a model. * **Maximum number of tokens/second**: to make text readable in real-time in case the model is generating too fast. Good if you want to flex and tell everyone how good your GPU is. * **Custom system message**: If not empty, will be used instead of the default system message in the instruction template. Useful for customizing the personality of the chatbot. Example: "You are a duck." * **Custom stopping strings**: The model stops generating as soon as any of the strings set in this field is generated. Note that when generating text in the Chat tab, some default stopping strings are set regardless of this parameter, like "\nYour Name:" and "\nBot name:" for chat mode. That's why this parameter has a "Custom" in its name. * **Custom token bans**: Allows you to ban the model from generating certain tokens altogether. You need to find the token IDs under "Default" > "Tokens" or "Notebook" > "Tokens", or by looking at the `tokenizer.json` for the model directly. * **auto_max_new_tokens**: When checked, the max_new_tokens parameter is expanded in the backend to the available context length. 
The maximum length is given by the "truncation_length" parameter. This is useful for getting long replies in the Chat tab without having to click on "Continue" many times. * **Ban the eos_token**: One of the possible tokens that a model can generate is the EOS (End of Sequence) token. When it is generated, the generation stops prematurely. When this parameter is checked, that token is banned from being generated, and the generation will always generate "max_new_tokens" tokens. * **Add the bos_token to the beginning of prompts**: By default, the tokenizer will add a BOS (Beginning of Sequence) token to your prompt. During training, BOS tokens are used to separate different documents. If unchecked, no BOS token will be added, and the model will interpret your prompt as being in the middle of a document instead of at the start of one. This significantly changes the output and can make it more creative. * **Skip special tokens**: When decoding the generated tokens, prevents special tokens from being converted to their text representation. Otherwise, BOS appears as `<s>`, EOS as `</s>`, etc. * **prompt_lookup_num_tokens**: Activates Prompt Lookup Decoding, a form of speculative decoding for the Transformers loader. It guesses future tokens by looking for matching patterns in the prompt itself, which can speed up generation for tasks that involve repeating or paraphrasing parts of the input. * **Activate text streaming**: When unchecked, the full response is output at once, without streaming the words one at a time. I recommend unchecking this parameter on high-latency connections, such as when running the web UI on Google Colab or using `--share`. * **Static KV cache**: Use a static cache for improved performance with the Transformers loader. May not be compatible with all models. * **Sampler priority**: Allows you to customize the order in which the different samplers are applied. The first sampler on the list gets applied first.
With this, custom orders like `top_p -> temperature -> top_k` can be defined. * **DRY sequence breakers**: Tokens across which DRY sequence matching is not continued. Typically punctuation and special tokens. Only used when DRY is active (dry_multiplier > 0). * **Load grammar from file**: Loads a GBNF grammar from a file under `user_data/grammars`. The output is written to the "Grammar" box below. You can also save and delete custom grammars using this menu. * **Grammar**: Allows you to constrain the model output to a particular format. For instance, you can make the model generate lists, JSON, specific words, etc. Grammar is extremely powerful and I highly recommend it. The syntax looks a bit daunting at first sight, but it gets very easy once you understand it. See the [GBNF Guide](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) for details. ### Chat tab controls The following parameters appear in the Chat tab sidebar rather than the Parameters tab: * **reasoning_effort**: Controls the thinking depth for models that support it (used by GPT-OSS). Options: low, medium, high. * **enable_thinking**: Enables extended thinking mode for models that support it (used by Seed-OSS and pre-2507 Qwen3). When enabled, the model can use a thinking step before generating its reply. ## Instruction template This sub-tab within the Parameters tab defines the instruction template used in the Chat tab when "instruct" or "chat-instruct" are selected under "Mode". * **Saved instruction templates**: A dropdown menu where you can select a template. Click **Load** to apply it. The 💾 button saves the current template, and the 🗑️ button deletes the selected one. * **Instruction template**: A Jinja2 template that defines the prompt format for the instruction-following conversation. * **Send to notebook**: Send the full instruction template in string format to the Notebook tab. 
* **Chat template**: A Jinja2 template that defines the prompt format for regular chat conversations with characters. ## Character tab The Character tab is a separate top-level tab that contains the following sub-tabs: ### Character Parameters that define the character used in the Chat tab when "chat" or "chat-instruct" are selected under "Mode". * **Character**: A dropdown menu where you can select from saved characters, save a new character (💾 button), and delete the selected character (🗑️). The **Restore character** button resets the character to its last saved state. * **Character's name**: The bot name as it appears in the prompt. * **Context**: A string that is always at the top of the prompt. It never gets truncated. It usually defines the bot's personality and some key elements of the conversation. * **Greeting**: An opening message for the bot. When set, it appears whenever you start a new chat. * **Character picture**: A profile picture for the bot. To make it apply, you need to save the bot by clicking on 💾. * **Your picture**: Your profile picture. It will be used in all conversations. Note: the following replacements take place in the context and greeting fields when the chat prompt is generated: * `{{char}}` and `<BOT>` get replaced with "Character's name". * `{{user}}` and `<USER>` get replaced with "Your name". So you can use those special placeholders in your character definitions. They are commonly found in TavernAI character cards. ### User Allows you to create and manage user profiles. * **User**: A dropdown to select, save (💾), or delete (🗑️) user profiles. * **Name**: Your name as it appears in the prompt. * **Description**: An optional description of yourself that can be referenced in conversations. ### Chat history In this tab, you can download the current chat history in JSON format and upload a previously saved chat history. When a history is uploaded, a new chat is created to hold it. This way, you don't lose your current chat in the Chat tab.
### Upload character #### YAML or JSON Allows you to upload characters in the YAML format used by the web UI, including optionally a profile picture. #### TavernAI PNG Allows you to upload a TavernAI character card. It will be converted to the internal YAML format of the web UI after upload. ================================================ FILE: docs/04 - Model Tab.md ================================================ This is where you load models, apply LoRAs to a loaded model, and download new models. ## Model loaders ### llama.cpp Loads: GGUF models. Note: GGML models have been deprecated and do not work anymore. Example: https://huggingface.co/TheBloke/Llama-2-7b-Chat-GGUF * **gpu_layers**: The number of layers to allocate to the GPU. If set to 0, only the CPU will be used. If you want to offload all layers, you can simply set this to the maximum value. * **ctx_size**: Context length of the model. In llama.cpp, the cache is preallocated, so the higher this value, the higher the VRAM. It is automatically set to the maximum sequence length for the model based on the metadata inside the GGUF file, but you may need to lower this value to fit the model into your GPU. Set to 0 for automatic context size based on available memory. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice. * **cache_type**: KV cache quantization type. Valid options: `fp16`, `q8_0`, `q4_0`. Lower quantization saves VRAM at the cost of some quality. * **tensor_split**: For multi-gpu only. Sets the amount of memory to allocate per GPU as proportions. Not to be confused with other loaders where this is set in GB; here you can set something like `30,70` for 30%/70%. * **batch_size**: Maximum number of prompt tokens to batch together when calling llama_eval. * **ubatch_size**: Physical maximum batch size for prompt processing. 
* **threads**: Number of threads. Recommended value: your number of physical cores. * **threads_batch**: Number of threads for batch processing. Recommended value: your total number of cores (physical + virtual). * **cpu_moe**: Force MoE expert layers to run on the CPU, keeping the rest on the GPU. * **extra_flags**: Extra flags to pass to llama-server. Format: `flag1=value1,flag2,flag3=value3`. Example: `override-tensor=exps=CPU`. * **mmproj**: Path to the mmproj file for multimodal (vision) models. This enables image understanding capabilities. * **streaming_llm**: Experimental feature to avoid re-evaluating the entire prompt when part of it is removed, for instance, when you hit the context length for the model in chat mode and an old message is removed. * **cpu**: Force a version of llama.cpp compiled without GPU acceleration to be used. Can usually be ignored. Only set this if you want to use CPU only and llama.cpp doesn't work otherwise. * **row_split**: Split the model by rows across GPUs. This may improve multi-gpu performance. * **no_kv_offload**: Do not offload the KV cache to the GPU. This saves VRAM but reduces performance. * **no_mmap**: Loads the model into memory at once, possibly preventing I/O operations later on at the cost of a longer load time. * **mlock**: Force the system to keep the model in RAM rather than swapping or compressing. * **numa**: May improve performance on certain multi-cpu systems. ### Transformers Loads: full precision (16-bit or 32-bit) models, as well as bitsandbytes-quantized models. The repository usually has a clean name without GGUF or EXL3 in its name, and the model files are named `model.safetensors` or split into parts like `model-00001-of-00004.safetensors`. Example: [https://huggingface.co/lmsys/vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5). 
Full precision models use a ton of VRAM, so you will usually want to select the "load_in_4bit" and "use_double_quant" options to load the model in 4-bit precision using bitsandbytes. Options: * **gpu_split**: When using multiple GPUs, sets the amount of VRAM in GB to allocate per GPU. Example: `20,7,7`. * **cpu_memory**: Maximum CPU memory in GiB to use for CPU offloading via the accelerate library. Whatever doesn't fit in the GPU or CPU will go to a disk cache if the "disk" checkbox is enabled. * **compute_dtype**: Used when "load_in_4bit" is checked. I recommend leaving the default value. * **quant_type**: Used when "load_in_4bit" is checked. I recommend leaving the default value. * **attn_implementation**: Choose the attention implementation. Valid options: `sdpa`, `eager`, `flash_attention_2`. The default (`sdpa`) works well in most cases; `flash_attention_2` may be useful for training. * **cpu**: Loads the model in CPU mode using Pytorch. The model will be loaded in 32-bit precision, so a lot of RAM will be used. CPU inference with transformers is older than llama.cpp and it works, but it's a lot slower. Note: this parameter has a different interpretation in the llama.cpp loader (see above). * **load_in_8bit**: Load the model in 8-bit precision using bitsandbytes. The 8-bit kernel in that library has been optimized for training and not inference, so load_in_8bit is slower than load_in_4bit (but more accurate). * **bf16**: Use bfloat16 precision instead of float16 (the default). Only applies when quantization is not used. * **disk**: Enable disk offloading for layers that don't fit into the GPU and CPU combined. * **load_in_4bit**: Load the model in 4-bit precision using bitsandbytes. * **use_double_quant**: Use double quantization with 4-bit loading for reduced memory usage. * **trust-remote-code**: Some models use custom Python code to load the model or the tokenizer. For such models, this option needs to be set. 
It doesn't download any remote content: all it does is execute the .py files that get downloaded with the model. Those files can potentially include malicious code; I have never seen it happen, but it is in principle possible. * **no_use_fast**: Do not use the "fast" version of the tokenizer. Can usually be ignored; only check this if you can't load the tokenizer for your model otherwise. ### ExLlamav3_HF Loads: EXL3 models. These models usually have "EXL3" or "exl3" in the model name. Uses the ExLlamaV3 backend with Transformers samplers. * **ctx_size**: Context length of the model. The cache is preallocated, so the higher this value, the more VRAM is used. It is automatically set to the maximum sequence length for the model based on its metadata, but you may need to lower this value to fit the model into your GPU. After loading the model, the "Truncate the prompt up to this length" parameter under "Parameters" > "Generation" is automatically set to your chosen "ctx_size" so that you don't have to set the same thing twice. * **cache_type**: KV cache quantization type. Valid options: `fp16`, `q2` to `q8`. You can also specify key and value bits separately, e.g. `q4_q8`. Lower quantization saves VRAM at the cost of some quality. * **gpu_split**: Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: `20,7,7`. * **cfg_cache**: Creates a second cache to hold the CFG negative prompts. You need to set this if and only if you intend to use CFG in the "Parameters" > "Generation" tab. Checking this parameter doubles the cache VRAM usage. * **no_use_fast**: Do not use the "fast" version of the tokenizer. * **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs. * **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`. ### ExLlamav3 The same as ExLlamav3_HF but using the internal samplers of ExLlamaV3 instead of the ones in the Transformers library.
Supports speculative decoding with a draft model. Also supports multimodal (vision) models natively. * **ctx_size**: Same as ExLlamav3_HF. * **cache_type**: Same as ExLlamav3_HF. * **gpu_split**: Same as ExLlamav3_HF. * **enable_tp**: Enable Tensor Parallelism (TP) to split the model across GPUs. * **tp_backend**: The backend for tensor parallelism. Valid options: `native`, `nccl`. Default: `native`. ### TensorRT-LLM Loads: TensorRT-LLM engine models. These are highly optimized models compiled specifically for NVIDIA GPUs. * **ctx_size**: Context length of the model. * **cpp_runner**: Use the ModelRunnerCpp runner, which is faster than the default ModelRunner but doesn't support streaming yet. ## Model dropdown Here you can select a model to be loaded, refresh the list of available models, load/unload/reload the selected model, and save the settings for the model. The "settings" are the values in the input fields (checkboxes, sliders, dropdowns) below this dropdown. After saving, those settings will get restored whenever you select that model again in the dropdown menu. If the **Autoload the model** checkbox is selected, the model will be loaded as soon as it is selected in this menu. Otherwise, you will have to click on the "Load" button. ## LoRA dropdown Used to apply LoRAs to the model. Note that LoRA support is not implemented for all loaders. Check the [What Works](https://github.com/oobabooga/text-generation-webui/wiki/What-Works) page for details. ## Download model or LoRA Here you can download a model or LoRA directly from the https://huggingface.co/ website. * Models will be saved to `user_data/models`. * LoRAs will be saved to `user_data/loras`. In the input field, you can enter either the Hugging Face username/model path (like `facebook/galactica-125m`) or the full model URL (like `https://huggingface.co/facebook/galactica-125m`). To specify a branch, add it at the end after a ":" character like this: `facebook/galactica-125m:main`. 
To download a single file, as necessary for models in GGUF format, you can click on "Get file list" after entering the model path in the input field, and then copy and paste the desired file name in the "File name" field before clicking on "Download". ================================================ FILE: docs/05 - Training Tab.md ================================================ ## Training Your Own LoRAs A LoRA is tied to a specific model architecture — a LoRA trained on Llama 3 8B won't work on Mistral 7B. Train on the exact model you plan to use. ### Quick Start 1. Load your base model with the **Transformers** loader (no LoRAs loaded). 2. Open the **Training** tab > **Train LoRA**. 3. Pick a dataset and configure parameters (see [below](#parameters)). 4. Click **Start LoRA Training** and monitor the [loss](#loss). 5. When done, load the LoRA from the **Models** tab and test it. ### Resuming Training To resume from a checkpoint, use the same LoRA name and uncheck `Override Existing Files`. If checkpoints exist (from `Save every n steps`), training will automatically resume from the latest one with full optimizer and scheduler state preserved. Note that you cannot change the `Rank` of an already created LoRA. You should also use `Copy parameters from` to restore the UI settings (learning rate, epochs, etc.) from the previous run, so that training continues with the same configuration. ### Troubleshooting - **Corrupted outputs**: Start over with a lower Learning Rate. - **Not learning enough**: Run more epochs, or increase the Rank. - **Unwanted formatting**: Tweak your dataset, or train for fewer steps. ## Instruction Templates All instruction/chat training uses `apply_chat_template()` with Jinja2 templates. You have two options in the **Instruction Template** dropdown: - **Chat Template**: Uses the model's built-in chat template from its tokenizer. Works with instruct/chat models that ship with a chat template (Llama 3, Qwen, Mistral, etc.). 
- **Named template** (e.g. ChatML, Alpaca, Llama-v3, etc.): Loads a Jinja2 template from `user_data/instruction-templates/`. This is useful for base models that don't have a built-in template, or when you want to override the model's default template. Both options are functionally identical — the only difference is where the Jinja2 template string comes from. In both cases: - The dataset is tokenized via `apply_chat_template()` - Labels are automatically masked so only assistant responses are trained on - Multi-turn conversations are supported natively - Special tokens are handled correctly by the template The WebUI ships with 50+ templates in `user_data/instruction-templates/`. You can also add your own by creating a `.yaml` file with an `instruction_template` key containing a Jinja2 template string, or a plain `.jinja` file. **Dataset formats:** Your JSON dataset can use either of these structures: OpenAI messages format: ```json [ { "messages": [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is Python?"}, {"role": "assistant", "content": "A programming language."}, {"role": "user", "content": "What's it used for?"}, {"role": "assistant", "content": "Web dev, data science, scripting, and more."} ] } ] ``` ShareGPT format (`conversations` key with `from`/`value` fields): ```json [ { "conversations": [ {"from": "system", "value": "You are a helpful assistant."}, {"from": "human", "value": "What is Python?"}, {"from": "gpt", "value": "A programming language."}, {"from": "human", "value": "What's it used for?"}, {"from": "gpt", "value": "Web dev, data science, scripting, and more."} ] } ] ``` ## Text Dataset For pretraining-style training on raw text, use the **Text Dataset** tab. 
Your dataset should be a JSON file with one document per row, each with a `"text"` key: ```json [ {"text": "First document content..."}, {"text": "Second document content..."} ] ``` This is the standard format used by most pretraining datasets (The Pile, RedPajama, etc.). Each document is tokenized (with BOS token), concatenated into one long token sequence, and split into chunks of `Cutoff Length` tokens. The final chunk is padded if shorter than the cutoff length. When `Add EOS token` is enabled, an EOS token is appended after each document before concatenation, helping the model learn document boundaries. - `Stride Length` controls the overlap between consecutive chunks in tokens. Set to 0 for non-overlapping chunks (the standard concatenate-and-split approach). Values like 256 or 512 create overlapping chunks that help the model learn context across chunk boundaries, at the cost of more training samples. ## Target Modules By default, **Target all linear layers** is enabled. This uses peft's `all-linear` mode, which applies LoRA to every `nn.Linear` layer in the model except the output head (`lm_head`). It works for any model architecture. If you uncheck it, you can manually select individual projection modules (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `down_proj`, `up_proj`). Targeting fewer modules reduces VRAM usage and adapter size, but also reduces how much the model can learn. The default selection of `q_proj` + `v_proj` is the minimum for basic style/format training. ## Parameters Each parameter has a description in the UI. Below is guidance on the most important choices. ### VRAM VRAM usage during training is roughly similar to inference with ~1000 tokens of context. If you can run the model, you can probably train LoRAs with the default settings. If you run out of VRAM, reduce `Micro Batch Size` or `Cutoff Length`. Training 4-bit quantized models uses more VRAM — set `Micro Batch Size` to `1` to compensate. 
### Rank Higher rank = more learning capacity = larger adapter = more VRAM. Use 4–8 for style/format, 128–256 to teach factual knowledge. ### Learning Rate and Epochs These control how aggressively the model learns and how many times it sees the data. Higher LR + fewer epochs = fast but rough. Lower LR + more epochs = slower but higher quality. The scheduler (default: cosine) decays the LR over the course of training — see [HuggingFace docs](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules#schedules) for graphs of each option. ## Loss When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes. Loss measures how far the model's predictions are from the training data, with `0` meaning a perfect match. It's calculated as the cross-entropy between the model's output distribution and the expected tokens. In practice, a loss of `0` means the model has overfit — it memorized the training data at the expense of its general capabilities. Loss is a balancing game: you want it low enough that the model learns your data, but not so low that it loses general knowledge. Generally, if it goes below `1.0`, overfitting is likely and you should stop training. In some cases you may want to go as low as `0.5` (if you need very predictable outputs). Different goals have different needs, so experiment and see what works best for you. Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (e.g., model corruption). ================================================ FILE: docs/06 - Session Tab.md ================================================ Here you can restart the UI with new settings. ## Settings * **Toggle light/dark theme**: switches between light and dark mode.
* **Show two columns in the Notebook tab**: toggles between the two-column Default layout and the single-column Notebook layout. * **Turn long pasted text into attachments in the Chat tab**: when enabled, long pasted text is automatically converted into file attachments. * **Include attachments/search results from previous messages in the chat prompt**: when enabled, attachments and web search results from earlier messages are included in subsequent prompts. ## Extensions & flags * **Available extensions**: shows a list of extensions available under `text-generation-webui/extensions` and `text-generation-webui/user_data/extensions`. Note that some of these extensions may require manually installing Python requirements through the command: `pip install -r extensions/extension_name/requirements.txt`. * **Boolean command-line flags**: shows command-line flags of bool (true/false) type. After selecting your desired flags and extensions, you can restart the UI by clicking on **Apply flags/extensions and restart**. ## Install or update an extension In this field, you can enter the GitHub URL for an extension and press enter to either install it (i.e. cloning it into `text-generation-webui/extensions`) or update it with `git pull` in case it is already cloned. Note that some extensions may include additional Python requirements. In this case, to install those you have to run the command ``` pip install -r extensions/extension-name/requirements.txt ``` or ``` pip install -r extensions\extension-name\requirements.txt ``` if you are on Windows. If you used the one-click installer, this command should be executed in the terminal window that appears when you run the "cmd_" script for your OS. ## Saving UI defaults The **Save extensions settings to user_data/settings.yaml** button gathers the visible values in the UI and saves them to `user_data/settings.yaml` so that your settings will persist across multiple restarts of the UI. 
Note that preset parameters like temperature are not individually saved, so you need to first save your preset and select it in the preset menu before saving the defaults. ================================================ FILE: docs/07 - Extensions.md ================================================ # Extensions Extensions are defined by files named `script.py` inside subfolders of either: - `text-generation-webui/extensions` - `text-generation-webui/user_data/extensions` They are loaded at startup if the folder name is specified after the `--extensions` flag. For instance, `extensions/silero_tts/script.py` or `user_data/extensions/silero_tts/script.py` gets loaded with `python server.py --extensions silero_tts`. **Note:** Extensions in `user_data/extensions/` take priority over those in `extensions/` when both exist with the same name. ## [text-generation-webui-extensions](https://github.com/oobabooga/text-generation-webui-extensions) The repository above contains a directory of user extensions. If you create an extension, you are welcome to host it in a GitHub repository and submit a PR adding it to the list. ## Built-in extensions |Extension|Description| |---------|-----------| |[openai](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/openai)| Creates an API that mimics the OpenAI API and can be used as a drop-in replacement. | |[superboogav2](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superboogav2)| Enhanced RAG extension with support for PDF, DOCX, and PPTX files. | |[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | |[coqui_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/coqui_tts)| Text-to-speech extension using Coqui XTTS v2. 
| |[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, responses are replaced with an audio widget. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[perplexity_colors](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/perplexity_colors)| Colors each token in the output text by its associated probability, as derived from the model logits. | |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.| |[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | |[long_replies](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/long_replies)| Forces longer replies by suppressing early newlines in the model output. | |[ngrok](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/ngrok)| Allows you to access the web UI remotely using the ngrok reverse tunnel service (free). It's an alternative to the built-in Gradio `--share` feature. | |[superbooga](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/superbooga)| An extension that uses ChromaDB to create an arbitrarily large pseudocontext, taking as input text files, URLs, or pasted text. Based on https://github.com/kaiokendev/superbig. 
| |[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that adds a hidden string at the beginning of the bot's reply in chat mode. | ## How to write an extension The extensions framework is based on special functions and variables that you can define in `script.py`. The functions are the following: | Function | Description | |-------------|-------------| | `def setup()` | Is executed when the extension gets imported. | | `def ui()` | Creates custom gradio elements when the UI is launched. | | `def custom_css()` | Returns custom CSS as a string. It is applied whenever the web UI is loaded. | | `def custom_js()` | Same as above but for javascript. | | `def input_modifier(string, state, is_chat=False)` | Modifies the input string before it enters the model. In chat mode, it is applied to the user message. Otherwise, it is applied to the entire prompt. | | `def output_modifier(string, state, is_chat=False)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def chat_input_modifier(text, visible_text, state)` | Modifies both the visible and internal inputs in chat mode. Can be used to hijack the chat input with custom content. | | `def bot_prefix_modifier(string, state)` | Applied in chat mode to the prefix for the bot's reply. | | `def state_modifier(state)` | Modifies the dictionary containing the UI input parameters before it is used by the text generation functions. | | `def history_modifier(history)` | Modifies the chat history before the text generation in chat mode begins. | | `def custom_generate_reply(...)` | Overrides the main text generation function. | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | | `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. 
Should return `prompt`, `input_ids`, `input_embeds`. See the `example` extension for a template. | | `def custom_tokenized_length(prompt)` | Used in conjunction with `tokenizer_modifier`, returns the length in tokens of `prompt`. See the `example` extension for a template. | Additionally, you can define a special `params` dictionary. In it, the `display_name` key is used to define the displayed name of the extension in the UI, and the `is_tab` key is used to define whether the extension should appear in a new tab. By default, extensions appear at the bottom of the "Text generation" tab. Example: ```python params = { "display_name": "Google Translate", "is_tab": True, } ``` The `params` dict may also contain variables that you want to be customizable through a `settings.yaml` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` in ```python params = { "display_name": "Google Translate", "is_tab": True, "language string": "ja" } ``` can be customized by adding a key called `google_translate-language string` to `settings.yaml`: ```yaml google_translate-language string: 'fr' ``` That is, the syntax for the key is `extension_name-variable_name`. ## Using multiple extensions at the same time You can activate more than one extension at a time by providing their names separated by spaces after `--extensions`. The input, output, and bot prefix modifiers will be applied in the specified order. Example: ``` python server.py --extensions enthusiasm translate # First apply enthusiasm, then translate python server.py --extensions translate enthusiasm # First apply translate, then enthusiasm ``` Note that for: - `custom_generate_chat_prompt` - `custom_generate_reply` - `custom_tokenized_length` only the first declaration encountered will be used and the rest will be ignored.
## A full example The source code below can be found at [extensions/example/script.py](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/example/script.py). ```python """ An example of extension. It does nothing, but you can add transformations before the return statements to customize the webui behavior. Starting from history_modifier and ending in output_modifier, the functions are declared in the same order that they are called at generation time. """ import gradio as gr import torch from transformers import LogitsProcessor from modules import chat, shared from modules.text_generation import ( decode, encode, generate_reply, ) params = { "display_name": "Example Extension", "is_tab": False, } class MyLogits(LogitsProcessor): """ Manipulates the probabilities for the next token before it gets sampled. Used in the logits_processor_modifier function below. """ def __init__(self): pass def __call__(self, input_ids, scores): # probs = torch.softmax(scores, dim=-1, dtype=torch.float) # probs[0] /= probs[0].sum() # scores = torch.log(probs / (1 - probs)) return scores def history_modifier(history): """ Modifies the chat history. Only used in chat mode. """ return history def state_modifier(state): """ Modifies the state variable, which is a dictionary containing the input values in the UI like sliders and checkboxes. """ return state def chat_input_modifier(text, visible_text, state): """ Modifies the user input string in chat mode (visible_text). You can also modify the internal representation of the user input (text) to change how it will appear in the prompt. """ return text, visible_text def input_modifier(string, state, is_chat=False): """ In default/notebook modes, modifies the whole prompt. In chat mode, it is the same as chat_input_modifier but only applied to "text", here called "string", and not to "visible_text". """ return string def bot_prefix_modifier(string, state): """ Modifies the prefix for the next bot reply in chat mode. 
By default, the prefix will be something like "Bot Name:". """ return string def tokenizer_modifier(state, prompt, input_ids, input_embeds): """ Modifies the input ids and embeds. Modifies the input ids and embeds fed to the model. Only used by loaders that use the transformers library for sampling. """ return prompt, input_ids, input_embeds def logits_processor_modifier(processor_list, input_ids): """ Adds logits processors to the list, allowing you to access and modify the next token probabilities. Only used by loaders that use the transformers library for sampling. """ processor_list.append(MyLogits()) return processor_list def output_modifier(string, state, is_chat=False): """ Modifies the LLM output before it gets presented. In chat mode, the modified version goes into history['visible'], and the original version goes into history['internal']. """ return string def custom_generate_chat_prompt(user_input, state, **kwargs): """ Replaces the function that generates the prompt from the chat history. Only used in chat mode. """ result = chat.generate_chat_prompt(user_input, state, **kwargs) return result def custom_css(): """ Returns a CSS string that gets appended to the CSS for the webui. """ return '' def custom_js(): """ Returns a javascript string that gets appended to the javascript for the webui. """ return '' def setup(): """ Gets executed only once, when the extension is imported. """ pass def ui(): """ Gets executed when the UI is drawn. Custom gradio elements and their corresponding event handlers should be defined here. To learn about gradio components, check out the docs: https://gradio.app/docs/ """ pass ``` ================================================ FILE: docs/08 - Additional Tips.md ================================================ ## Audio notification If your computer takes a long time to generate each response for the model that you are using, you can enable an audio notification for when the response is completed. 
This feature was kindly contributed by HappyWorldGames in [#1277](https://github.com/oobabooga/text-generation-webui/pull/1277). ### Installation Simply place a file called "notification.mp3" in the same folder as `server.py`. Here you can find some examples: * https://pixabay.com/sound-effects/search/ding/?duration=0-30 * https://pixabay.com/sound-effects/search/notification/?duration=0-30 Source: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/1126 This file will be automatically detected the next time you start the web UI. ## Miscellaneous info ### You can train LoRAs in CPU mode Load the web UI with ``` python server.py --cpu ``` and start training the LoRA from the training tab as usual. ### You can check the sha256sum of downloaded models with the download script ``` python download-model.py facebook/galactica-125m --check ``` ### The download script continues interrupted downloads by default It doesn't start over. ================================================ FILE: docs/09 - Docker.md ================================================ Docker Compose is a way of installing and launching the web UI in an isolated Ubuntu image using only a few commands. ## Prerequisites You need Docker Compose v2.17 or higher: ``` ~$ docker compose version Docker Compose version v2.21.0 ``` Installation instructions: https://docs.docker.com/engine/install/ For NVIDIA GPUs, you also need the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). 
## Quick start There are four Docker variants available under `docker/`: | Directory | GPU | Notes | |-----------|-----|-------| | `docker/nvidia` | NVIDIA | Requires NVIDIA Container Toolkit | | `docker/amd` | AMD | Requires ROCm-compatible GPU | | `docker/intel` | Intel Arc | Beta support | | `docker/cpu` | None | CPU-only inference | To launch (using NVIDIA as an example): ```bash cd text-generation-webui/docker/nvidia cp ../.env.example .env # Optionally edit .env to customize ports, TORCH_CUDA_ARCH_LIST, etc. docker compose up --build ``` The web UI will be available at `http://localhost:7860`. ## User data Create a `user_data/` directory next to the `docker-compose.yml` to persist your models, characters, presets, and settings between container rebuilds: ```bash mkdir -p user_data ``` This directory is mounted into the container at runtime. You can place a `CMD_FLAGS.txt` inside it to pass persistent flags to the web UI (e.g., `--api`). Models can be downloaded through the web UI's “Model” tab once it's running, and they will be saved to `user_data/models/`. ## Dedicated docker repository An external repository maintains a docker wrapper for this project as well as several pre-configured 'one-click' `docker compose` variants. It can be found at: [Atinoda/text-generation-webui-docker](https://github.com/Atinoda/text-generation-webui-docker). ================================================ FILE: docs/11 - AMD Setup.md ================================================ ## Using an AMD GPU in Linux Requires ROCm 6.4 to be installed. ### Option 1: One-click installer The one-click installer (`start_linux.sh`) automatically detects AMD GPUs. 
When prompted, select the AMD option, or set the `GPU_CHOICE` environment variable before running: ``` GPU_CHOICE=B ./start_linux.sh ``` ### Option 2: Manual conda install Follow the manual conda installation instructions in the README, using the AMD PyTorch command: ``` pip3 install torch==2.9.1 --index-url https://download.pytorch.org/whl/rocm6.4 ``` Then install the project requirements with the AMD requirements file: ``` pip install -r requirements/full/requirements_amd.txt ``` ================================================ FILE: docs/12 - OpenAI API.md ================================================ ## OpenAI compatible API The main API for this project is meant to be a drop-in replacement to the OpenAI API, including Chat and Completions endpoints. * It is 100% offline and private. * It doesn't create any logs. * It doesn't connect to OpenAI. * It doesn't use the openai-python library. ### Starting the API Add `--api` to your command-line flags. * To create a public Cloudflare URL, add the `--public-api` flag. * To listen on your local network, add the `--listen` flag. * To change the port, which is 5000 by default, use `--api-port 1234` (change 1234 to your desired port number). * To use SSL, add `--ssl-keyfile key.pem --ssl-certfile cert.pem`. ⚠️ **Note**: this doesn't work with `--public-api` since Cloudflare already uses HTTPS by default. * To use an API key for authentication, add `--api-key yourkey`. ### Examples For the documentation with all the endpoints, parameters and their types, consult `http://127.0.0.1:5000/docs` or the [typing.py](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/openai/typing.py) file. The official examples in the [OpenAI documentation](https://platform.openai.com/docs/api-reference) should also work, and the same parameters apply (although the API here has more optional parameters). 
#### Completions ```shell curl http://127.0.0.1:5000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "prompt": "This is a cake recipe:\n\n1.", "max_tokens": 512, "temperature": 0.6, "top_p": 0.95, "top_k": 20 }' ``` #### Chat completions Works best with instruction-following models. If the "instruction_template" variable is not provided, it will be guessed automatically based on the model name using the regex patterns in `user_data/models/config.yaml`. ```shell curl http://127.0.0.1:5000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [ { "role": "user", "content": "Hello!" } ], "temperature": 0.6, "top_p": 0.95, "top_k": 20 }' ``` #### Chat completions with characters ```shell curl http://127.0.0.1:5000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [ { "role": "user", "content": "Hello! Who are you?" } ], "mode": "chat-instruct", "character": "Example", "temperature": 0.6, "top_p": 0.95, "top_k": 20 }' ``` #### Multimodal/vision (llama.cpp and ExLlamaV3) ##### With /v1/chat/completions (recommended!) ```shell curl http://127.0.0.1:5000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [ { "role": "user", "content": [ {"type": "text", "text": "Please describe what you see in this image."}, {"type": "image_url", "image_url": {"url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true"}} ] } ], "temperature": 0.6, "top_p": 0.95, "top_k": 20 }' ``` For base64-encoded images, just replace the inner "url" value with this format: `data:image/FORMAT;base64,BASE64_STRING` where FORMAT is the file type (png, jpeg, gif, etc.) and BASE64_STRING is your base64-encoded image data. 
##### With /v1/completions ```shell curl http://127.0.0.1:5000/v1/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [ { "role": "user", "content": [ { "type": "text", "text": "About image <__media__> and image <__media__>, what I can say is that the first one" }, { "type": "image_url", "image_url": { "url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/cat.png?raw=true" } }, { "type": "image_url", "image_url": { "url": "https://github.com/turboderp-org/exllamav3/blob/master/examples/media/strawberry.png?raw=true" } } ] } ], "temperature": 0.6, "top_p": 0.95, "top_k": 20 }' ``` For base64-encoded images, just replace the inner "url" values with this format: `data:image/FORMAT;base64,BASE64_STRING` where FORMAT is the file type (png, jpeg, gif, etc.) and BASE64_STRING is your base64-encoded image data. #### Image generation ```shell curl http://127.0.0.1:5000/v1/images/generations \ -H "Content-Type: application/json" \ -d '{ "prompt": "an orange tree", "steps": 9, "cfg_scale": 0, "batch_size": 1, "batch_count": 1 }' ``` You need to load an image model first. You can do this via the UI, or by adding `--image-model your_model_name` when launching the server. The output is a JSON object containing a `data` array. Each element has a `b64_json` field with the base64-encoded PNG image: ```json { "created": 1764791227, "data": [ { "b64_json": "iVBORw0KGgo..." } ] } ``` #### SSE streaming ```shell curl http://127.0.0.1:5000/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ "messages": [ { "role": "user", "content": "Hello!" } ], "temperature": 0.6, "top_p": 0.95, "top_k": 20, "stream": true }' ``` #### Logits ```shell curl -k http://127.0.0.1:5000/v1/internal/logits \ -H "Content-Type: application/json" \ -d '{ "prompt": "Who is best, Asuka or Rei? 
Answer:", "use_samplers": false }' ``` #### Logits after sampling parameters ```shell curl -k http://127.0.0.1:5000/v1/internal/logits \ -H "Content-Type: application/json" \ -d '{ "prompt": "Who is best, Asuka or Rei? Answer:", "use_samplers": true, "top_k": 3 }' ``` #### List models ```shell curl -k http://127.0.0.1:5000/v1/internal/model/list \ -H "Content-Type: application/json" ``` #### Load model ```shell curl -k http://127.0.0.1:5000/v1/internal/model/load \ -H "Content-Type: application/json" \ -d '{ "model_name": "Qwen_Qwen3-0.6B-Q4_K_M.gguf", "args": { "ctx_size": 32768, "flash_attn": true, "cache_type": "q8_0" } }' ``` #### Python chat example ```python import requests url = "http://127.0.0.1:5000/v1/chat/completions" headers = { "Content-Type": "application/json" } history = [] while True: user_message = input("> ") history.append({"role": "user", "content": user_message}) data = { "messages": history, "temperature": 0.6, "top_p": 0.95, "top_k": 20 } response = requests.post(url, headers=headers, json=data, verify=False) assistant_message = response.json()['choices'][0]['message']['content'] history.append({"role": "assistant", "content": assistant_message}) print(assistant_message) ``` #### Python chat example with streaming Start the script with `python -u` to see the output in real time. 
```python import requests import sseclient # pip install sseclient-py import json url = "http://127.0.0.1:5000/v1/chat/completions" headers = { "Content-Type": "application/json" } history = [] while True: user_message = input("> ") history.append({"role": "user", "content": user_message}) data = { "stream": True, "messages": history, "temperature": 0.6, "top_p": 0.95, "top_k": 20 } stream_response = requests.post(url, headers=headers, json=data, verify=False, stream=True) client = sseclient.SSEClient(stream_response) assistant_message = '' for event in client.events(): payload = json.loads(event.data) chunk = payload['choices'][0]['delta']['content'] assistant_message += chunk print(chunk, end='') print() history.append({"role": "assistant", "content": assistant_message}) ``` #### Python completions example with streaming Start the script with `python -u` to see the output in real time. ```python import json import requests import sseclient # pip install sseclient-py url = "http://127.0.0.1:5000/v1/completions" headers = { "Content-Type": "application/json" } data = { "prompt": "This is a cake recipe:\n\n1.", "max_tokens": 512, "temperature": 0.6, "top_p": 0.95, "top_k": 20, "stream": True, } stream_response = requests.post(url, headers=headers, json=data, verify=False, stream=True) client = sseclient.SSEClient(stream_response) print(data['prompt'], end='') for event in client.events(): payload = json.loads(event.data) print(payload['choices'][0]['text'], end='') print() ``` #### Python parallel requests example The API supports handling multiple requests in parallel. For ExLlamaV3, this works out of the box. For llama.cpp, you need to pass `--parallel N` to set the number of concurrent slots. 
```python import concurrent.futures import requests url = "http://127.0.0.1:5000/v1/chat/completions" prompts = [ "Write a haiku about the ocean.", "Explain quantum computing in simple terms.", "Tell me a joke about programmers.", ] def send_request(prompt): response = requests.post(url, json={ "messages": [{"role": "user", "content": prompt}], "max_tokens": 200, }) return response.json()["choices"][0]["message"]["content"] with concurrent.futures.ThreadPoolExecutor() as executor: results = list(executor.map(send_request, prompts)) for prompt, result in zip(prompts, results): print(f"Q: {prompt}\nA: {result}\n") ``` #### Python example with API key Replace ```python headers = { "Content-Type": "application/json" } ``` with ```python headers = { "Content-Type": "application/json", "Authorization": "Bearer yourPassword123" } ``` in any of the examples above. #### Tool/Function calling Use a model with tool calling support (Qwen, Mistral, GPT-OSS, etc). Tools are passed via the `tools` parameter and the prompt is automatically formatted using the model's Jinja2 template. When the model decides to call a tool, the response will have `finish_reason: "tool_calls"` and a `tool_calls` array with structured function names and arguments. You then execute the tool, send the result back as a `role: "tool"` message, and continue until the model responds with `finish_reason: "stop"`. Some models call multiple tools in parallel (Qwen, Mistral), while others call one at a time (GPT-OSS). The loop below handles both styles. 
```python import json import requests url = "http://127.0.0.1:5000/v1/chat/completions" # Define your tools tools = [ { "type": "function", "function": { "name": "get_weather", "description": "Get the current weather for a given location", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"}, }, "required": ["location"] } } }, { "type": "function", "function": { "name": "get_time", "description": "Get the current time in a given timezone", "parameters": { "type": "object", "properties": { "timezone": {"type": "string", "description": "IANA timezone string"}, }, "required": ["timezone"] } } }, ] def execute_tool(name, arguments): """Replace this with your actual tool implementations.""" if name == "get_weather": return {"temperature": 22, "condition": "sunny", "humidity": 45} elif name == "get_time": return {"time": "2:30 PM", "timezone": "JST"} return {"error": f"Unknown tool: {name}"} messages = [{"role": "user", "content": "What time is it in Tokyo and what's the weather like there?"}] # Tool-calling loop: keep going until the model gives a final answer for _ in range(10): response = requests.post(url, json={"messages": messages, "tools": tools}).json() choice = response["choices"][0] if choice["finish_reason"] == "tool_calls": # Add the assistant's response (with tool_calls) to history messages.append({ "role": "assistant", "content": choice["message"]["content"], "tool_calls": choice["message"]["tool_calls"], }) # Execute each tool and add results to history for tool_call in choice["message"]["tool_calls"]: name = tool_call["function"]["name"] arguments = json.loads(tool_call["function"]["arguments"]) result = execute_tool(name, arguments) print(f"Tool call: {name}({arguments}) => {result}") messages.append({ "role": "tool", "tool_call_id": tool_call["id"], "content": json.dumps(result), }) else: # Final answer print(f"\nAssistant: {choice['message']['content']}") break ``` ### Environment variables The 
following environment variables can be used (they take precedence over everything else): | Variable Name | Description | Example Value | |------------------------|------------------------------------|----------------------------| | `OPENEDAI_PORT` | Port number | 5000 | | `OPENEDAI_CERT_PATH` | SSL certificate file path | cert.pem | | `OPENEDAI_KEY_PATH` | SSL key file path | key.pem | | `OPENEDAI_DEBUG` | Enable debugging (set to 1) | 1 | | `OPENEDAI_EMBEDDING_MODEL` | Embedding model (if applicable) | sentence-transformers/all-mpnet-base-v2 | | `OPENEDAI_EMBEDDING_DEVICE` | Embedding device (if applicable) | cuda | #### Persistent settings with `settings.yaml` You can also set the following variables in your `settings.yaml` file: ``` openai-embedding_device: cuda openai-embedding_model: "sentence-transformers/all-mpnet-base-v2" openai-debug: 1 ``` ### Third-party application setup You can usually force an application that uses the OpenAI API to connect to the local API by using the following environment variables: ```shell OPENAI_API_HOST=http://127.0.0.1:5000 ``` or ```shell OPENAI_API_KEY=sk-111111111111111111111111111111111111111111111111 OPENAI_API_BASE=http://127.0.0.1:5000/v1 ``` With the [official python openai client](https://github.com/openai/openai-python) (v1.x), the address can be set like this: ```python from openai import OpenAI client = OpenAI( api_key="sk-111111111111111111111111111111111111111111111111", base_url="http://127.0.0.1:5000/v1" ) response = client.chat.completions.create( model="x", messages=[{"role": "user", "content": "Hello!"}] ) print(response.choices[0].message.content) ``` With the [official Node.js openai client](https://github.com/openai/openai-node) (v4.x): ```js import OpenAI from "openai"; const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY, baseURL: "http://127.0.0.1:5000/v1", }); const response = await client.chat.completions.create({ model: "x", messages: [{ role: "user", content: "Hello!" 
}], }); console.log(response.choices[0].message.content); ``` ### Embeddings (alpha) Embeddings requires `sentence-transformers` installed, but chat and completions will function without it loaded. The embeddings endpoint is currently using the HuggingFace model: `sentence-transformers/all-mpnet-base-v2` for embeddings. This produces 768 dimensional embeddings. The model is small and fast. This model and embedding size may change in the future. | model name | dimensions | input max tokens | speed | size | Avg. performance | | ---------------------- | ---------- | ---------------- | ----- | ---- | ---------------- | | all-mpnet-base-v2 | 768 | 384 | 2800 | 420M | 63.3 | | all-MiniLM-L6-v2 | 384 | 256 | 14200 | 80M | 58.8 | In short, the all-MiniLM-L6-v2 model is 5x faster, 5x smaller ram, 2x smaller storage, and still offers good quality. Stats from (https://www.sbert.net/docs/pretrained_models.html). To change the model from the default you can set the environment variable `OPENEDAI_EMBEDDING_MODEL`, ex. "OPENEDAI_EMBEDDING_MODEL=all-MiniLM-L6-v2". Warning: You cannot mix embeddings from different models even if they have the same dimensions. They are not comparable. ### Compatibility | API endpoint | notes | | ------------------------- | --------------------------------------------------------------------------- | | /v1/chat/completions | Use with instruction-following models. Supports streaming, tool calls. | | /v1/completions | Text completion endpoint. | | /v1/embeddings | Using SentenceTransformer embeddings. | | /v1/images/generations | Image generation, response_format='b64_json' only. | | /v1/moderations | Basic support via embeddings. | | /v1/models | Lists models. Currently loaded model first. | | /v1/models/{id} | Returns model info. | | /v1/audio/\* | Supported. | | /v1/images/edits | Not yet supported. | | /v1/images/variations | Not yet supported. 
| #### Applications Almost everything needs the `OPENAI_API_KEY` and `OPENAI_API_BASE` environment variables set, but there are some exceptions. | Compatibility | Application/Library | Website | Notes | | ------------- | -------------------- | ------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------- | | ✅❌ | openai-python | https://github.com/openai/openai-python | Use `OpenAI(base_url="http://127.0.0.1:5000/v1")`. Only the endpoints from above work. | | ✅❌ | openai-node | https://github.com/openai/openai-node | Use `new OpenAI({baseURL: "http://127.0.0.1:5000/v1"})`. See example above. | | ✅ | anse | https://github.com/anse-app/anse | API Key & URL configurable in UI, Images also work. | | ✅ | shell_gpt | https://github.com/TheR1D/shell_gpt | OPENAI_API_HOST=http://127.0.0.1:5000 | | ✅ | gpt-shell | https://github.com/jla/gpt-shell | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅ | gpt-discord-bot | https://github.com/openai/gpt-discord-bot | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅ | OpenAI for Notepad++ | https://github.com/Krazal/nppopenai | api_url=http://127.0.0.1:5000 in the config file, or environment variables. | | ✅ | vscode-openai | https://marketplace.visualstudio.com/items?itemName=AndrewButson.vscode-openai | OPENAI_API_BASE=http://127.0.0.1:5000/v1 | | ✅❌ | langchain | https://github.com/hwchase17/langchain | Use `base_url="http://127.0.0.1:5000/v1"`. Results depend on model and prompt formatting. 
| ================================================ FILE: docs/13 - Keyboard Shortcuts.md ================================================ # Keyboard Shortcuts #### General | Shortcut | Description | |-------------------------|--------------------------------------------------| | Esc | Stop generation | #### Chat tab | Shortcut | Description | |-------------------------|--------------------------------------------------| | Ctrl + S | Show/hide chat controls | | Ctrl + Enter | Regenerate | | Alt + Enter | Continue | | Ctrl + Shift + Backspace| Remove last | | Ctrl + Shift + M | Impersonate | | ← (Left Arrow) | Navigate to previous version of last assistant message | | → (Right Arrow) | Navigate to next version of last assistant message (or regenerate if at latest version) | ================================================ FILE: docs/Image Generation Tutorial.md ================================================ # Image Generation Tutorial This feature allows you to generate images using `diffusers` models like [Tongyi-MAI/Z-Image-Turbo](https://huggingface.co/Tongyi-MAI/Z-Image-Turbo) directly within the web UI. print ## Installation 1. Clone the repository with ``` git clone https://github.com/oobabooga/text-generation-webui ``` or download it from [here](https://github.com/oobabooga/text-generation-webui/archive/refs/heads/main.zip) and unzip it. 2. Use the one-click installer. - Windows: Double click on `start_windows.bat` - Linux: Run `./start_linux.sh` - macOS: Run `./start_macos.sh` Note: Image generation does not work with the portable builds in `.zip` format in the [Releases page](https://github.com/oobabooga/text-generation-webui/releases). You need the "full" version of the web UI. ## Downloading a model 1. Once installation ends, browse to `http://127.0.0.1:7860/`. 2. Click on "Image AI" on the left. 3. Click on "Model" at the top. 4. In the "Download model" field, paste `https://huggingface.co/Tongyi-MAI/Z-Image-Turbo` and click "Download". 5. 
Wait for the download to finish (it's 31 GB). ## Loading the model Select the quantization option in the "Quantization" menu and click "Load". The memory usage for `Z-Image-Turbo` for each option is: | Quantization Method | VRAM Usage | | :--- | :--- | | None (FP16/BF16) | 25613 MiB | | bnb-8bit | 16301 MiB | | bnb-8bit + CPU Offload | 16235 MiB | | bnb-4bit | 11533 MiB | | bnb-4bit + CPU Offload | 7677 MiB | The `torchao` options support `torch.compile` for faster image generation, with `float8wo` specifically providing native hardware acceleration for RTX 40-series and newer GPUs. Note: The next time you launch the web UI, the model will get automatically loaded with your last settings when you try to generate an image. You do not need to go to the Model tab and click "Load" each time. ## Generating images: 1. While still in the "Image AI" page, go to the "Generate" tab. 2. Type your prompt and click on the Generate button. ### Model-specific settings - For Z-Image-Turbo, make sure to keep CFG Scale at 0 and Steps at 9. Do not write a Negative Prompt as it will get ignored with this CFG Scale value. ### LLM Prompt Variations To use this feature, you need to load an LLM in the main "Model" page on the left. If you have no idea what to use, do this to get started: 1. Download [Qwen3-4B-Q3_K_M.gguf](https://huggingface.co/unsloth/Qwen3-4B-GGUF/resolve/main/Qwen3-4B-Q3_K_M.gguf) to your `text-generation-webui/user_data/models` folder. 2. Select the model in the dropdown menu in the "Model" page. 3. Click Load. Then go back to the "Image AI" page and check "LLM Prompt Variations". After that, your prompts will be automatically updated by the LLM each time you generate an image. If you use a "Sequential Count" value greater than 1, a new prompt will be created for each sequential batch. 
The improvement in creativity is striking (prompt: `Photo of a beautiful woman at night under moonlight`): comparison_collage ## Generating images over API It is possible to generate images using the project's API. Just make sure to start the server with `--api`, either by 1. Passing the `--api` flag to your `start` script, like `./start_linux.sh --api`, or 2. Writing `--api` to your `user_data/CMD_FLAGS.txt` file and relaunching the web UI. Here is an API call example: ``` curl http://127.0.0.1:5000/v1/images/generations \ -H "Content-Type: application/json" \ -d '{ "prompt": "an orange tree", "steps": 9, "cfg_scale": 0, "batch_size": 1, "batch_count": 1 }' ``` ================================================ FILE: docs/Multimodal Tutorial.md ================================================ ## Getting started ### 1. Find a multimodal model GGUF models with vision capabilities are uploaded along a `mmproj` file to Hugging Face. For instance, [unsloth/gemma-3-4b-it-GGUF](https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/tree/main) has this: print1 ### 2. Download the model to `user_data/models` As an example, download https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/gemma-3-4b-it-Q4_K_S.gguf?download=true to your `text-generation-webui/user_data/models` folder. ### 3. Download the associated mmproj file to `user_data/mmproj` Then download https://huggingface.co/unsloth/gemma-3-4b-it-GGUF/resolve/main/mmproj-F16.gguf?download=true to your `text-generation-webui/user_data/mmproj` folder. Name it `mmproj-gemma-3-4b-it-F16.gguf` to give it a recognizable name. ### 4. Load the model 1. Launch the web UI 2. Navigate to the Model tab 3. Select the GGUF model in the Model dropdown: print2 4. Select the mmproj file in the Multimodal (vision) menu: print3 5. Click "Load" ### 5. 
Send a message with an image Select your image by clicking on the 📎 icon and send your message: print5 The model will reply with great understanding of the image contents: print6 ## Multimodal with ExLlamaV3 Multimodal also works with the ExLlamaV3 loader (the non-HF one). No additional files are necessary, just load a multimodal EXL3 model and send an image. Examples of models that you can use: - https://huggingface.co/turboderp/gemma-3-27b-it-exl3 - https://huggingface.co/turboderp/Mistral-Small-3.1-24B-Instruct-2503-exl3 ## Multimodal API examples In the page below you can find some ready-to-use examples: [Multimodal/vision (llama.cpp and ExLlamaV3)](https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#multimodalvision-llamacpp-and-exllamav3) ================================================ FILE: docs/README.md ================================================ These files are a mirror of the documentation at: # https://github.com/oobabooga/text-generation-webui/wiki It is recommended to browse it there. Contributions can be sent here and will later be synced with the wiki. ================================================ FILE: docs/Tool Calling Tutorial.md ================================================ ## Supported models The following models are supported: - Qwen 3.5 - GPT-OSS - Mistral Small / Devstral - DeepSeek V3 - Kimi-K2 - MiniMax-M2.5 - GLM-5 - Llama 4 Other models that output tool calls as JSON (inside XML tags, code blocks, or plain JSON) are also supported through a generic fallback parser. ## Tool calling in the UI ### 1. Load a model with tool-calling support Load a model with tool-calling support from the Model tab. ### 2. Select tools In the chat sidebar, check the tools you want the model to use: - **web_search** -- Search the web using DuckDuckGo. - **fetch_webpage** -- Fetch the content of a URL. - **calculate** -- Evaluate math expressions. - **get_datetime** -- Get the current date and time. 
- **roll_dice** -- Roll dice. ### 3. Chat Send a message as usual. When the model decides it needs a tool, it will call it automatically. You will see each tool call and its result in a collapsible accordion inside the chat message. The model may call multiple tools in sequence before giving its final answer. ## Writing custom tools Each tool is a single `.py` file in `user_data/tools/`. It needs two things: 1. A `tool` dictionary that describes the function (name, description, parameters). 2. An `execute(arguments)` function that runs it and returns the result. Here is a minimal example (`user_data/tools/get_datetime.py`): ```python from datetime import datetime tool = { "type": "function", "function": { "name": "get_datetime", "description": "Get the current date and time.", "parameters": { "type": "object", "properties": {}, } } } def execute(arguments): now = datetime.now() return {"date": now.strftime("%Y-%m-%d"), "time": now.strftime("%I:%M %p")} ``` An example with parameters (`user_data/tools/roll_dice.py`): ```python import random tool = { "type": "function", "function": { "name": "roll_dice", "description": "Roll one or more dice with the specified number of sides.", "parameters": { "type": "object", "properties": { "count": {"type": "integer", "description": "Number of dice to roll.", "default": 1}, "sides": {"type": "integer", "description": "Number of sides per die.", "default": 20}, }, } } } def execute(arguments): count = max(1, min(arguments.get("count", 1), 1000)) sides = max(2, min(arguments.get("sides", 20), 1000)) rolls = [random.randint(1, sides) for _ in range(count)] return {"rolls": rolls, "total": sum(rolls)} ``` You can open the built-in tools in `user_data/tools/` for more examples. ## Tool calling over the API Tool calling over the API follows the [OpenAI API](https://platform.openai.com/docs/guides/function-calling) convention. 
Define your tools, send them with your messages, and handle tool calls in a loop until the model gives a final answer. ```python import json import requests url = "http://127.0.0.1:5000/v1/chat/completions" tools = [ { "type": "function", "function": { "name": "get_weather", "description": "Get the current weather for a given location.", "parameters": { "type": "object", "properties": { "location": {"type": "string", "description": "City name"}, }, "required": ["location"] } } } ] def execute_tool(name, arguments): if name == "get_weather": return {"temperature": "14°C", "condition": "partly cloudy"} return {"error": f"Unknown tool: {name}"} messages = [{"role": "user", "content": "What's the weather like in Paris?"}] for _ in range(10): response = requests.post(url, json={"messages": messages, "tools": tools}).json() choice = response["choices"][0] if choice["finish_reason"] == "tool_calls": messages.append({ "role": "assistant", "content": choice["message"]["content"], "tool_calls": choice["message"]["tool_calls"], }) for tool_call in choice["message"]["tool_calls"]: name = tool_call["function"]["name"] arguments = json.loads(tool_call["function"]["arguments"]) result = execute_tool(name, arguments) print(f"Tool call: {name}({arguments}) => {result}") messages.append({ "role": "tool", "tool_call_id": tool_call["id"], "content": json.dumps(result), }) else: print(f"\nAssistant: {choice['message']['content']}") break ``` ================================================ FILE: docs/What Works.md ================================================ ## What Works | Loader | Loading LoRAs | Training LoRAs | Multimodal | Perplexity evaluation | |----------------|---------------|----------------|------------|-----------------------| | llama.cpp | ❌ | ❌ | ✅\* | ❌ | | Transformers | ✅ | ✅ | ✅\*\* | ✅ | | ExLlamav3_HF | ❌ | ❌ | ❌ | ✅ | | ExLlamav3 | ❌ | ❌ | ✅ | ❌ | | TensorRT-LLM | ❌ | ❌ | ❌ | ❌ | ❌ = not supported ✅ = supported \* Via the `mmproj` parameter (multimodal 
projector file). \*\* Via the `send_pictures` extension. ================================================ FILE: download-model.py ================================================ ''' Downloads models from Hugging Face to user_data/models/username_modelname. Example: python download-model.py facebook/opt-1.3b ''' import argparse import base64 import datetime import hashlib import json import os import re import sys from multiprocessing import Array from pathlib import Path from time import sleep import requests import tqdm from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, RequestException, Timeout from tqdm.contrib.concurrent import thread_map from modules.paths import resolve_user_data_dir base = os.environ.get("HF_ENDPOINT") or "https://huggingface.co" class ModelDownloader: def __init__(self, max_retries=7): self.max_retries = max_retries self.session = self.get_session() self._progress_bar_slots = None self.progress_queue = None def get_session(self): session = requests.Session() if self.max_retries: session.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=self.max_retries)) session.mount('https://huggingface.co', HTTPAdapter(max_retries=self.max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: session.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) try: from huggingface_hub import get_token token = get_token() except ImportError: token = os.getenv("HF_TOKEN") if token is not None: session.headers = {'authorization': f'Bearer {token}'} return session def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/': model = model[:-1] if model.startswith(base + '/'): model = model[len(base) + 1:] model_parts = model.split(":") model = model_parts[0] if len(model_parts) > 0 else model branch = model_parts[1] if len(model_parts) > 1 else branch if branch is None: branch = "main" else: pattern = re.compile(r"^[a-zA-Z0-9._-]+$") if not pattern.match(branch): 
raise ValueError( "Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") return model, branch def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None, exclude_pattern=None): session = self.session page = f"/api/models/{model}/tree/{branch}" cursor = b"" links = [] sha256 = [] file_sizes = [] classifications = [] has_pytorch = False has_pt = False has_gguf = False has_safetensors = False is_lora = False while True: url = f"{base}{page}" + (f"?cursor={cursor.decode()}" if cursor else "") r = session.get(url, timeout=10) r.raise_for_status() content = r.content dict = json.loads(content) if len(dict) == 0: break for i in range(len(dict)): fname = dict[i]['path'] if specific_file not in [None, ''] and fname != specific_file: continue # Exclude files matching the exclude pattern if exclude_pattern is not None and re.match(exclude_pattern, fname): continue if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')): is_lora = True is_pytorch = re.match(r"(pytorch|adapter|gptq)_model.*\.bin", fname) is_safetensors = re.match(r".*\.safetensors", fname) is_pt = re.match(r".*\.pt", fname) is_gguf = re.match(r".*\.gguf", fname) is_tiktoken = re.match(r".*\.tiktoken", fname) is_tokenizer = re.match(r"(tokenizer|ice|spiece).*\.model", fname) or is_tiktoken is_text = re.match(r".*\.(txt|json|py|md)", fname) or is_tokenizer if any((is_pytorch, is_safetensors, is_pt, is_gguf, is_tokenizer, is_text)): file_size = 0 if 'lfs' in dict[i]: sha256.append([fname, dict[i]['lfs']['oid']]) file_size = dict[i]['lfs'].get('size', 0) elif 'size' in dict[i]: file_size = dict[i]['size'] file_sizes.append(file_size) if is_text: links.append(f"{base}/{model}/resolve/{branch}/{fname}") classifications.append('text') continue if not text_only: links.append(f"{base}/{model}/resolve/{branch}/{fname}") if is_safetensors: has_safetensors = True classifications.append('safetensors') elif is_pytorch: 
has_pytorch = True classifications.append('pytorch') elif is_pt: has_pt = True classifications.append('pt') elif is_gguf: has_gguf = True classifications.append('gguf') cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(cursor) cursor = cursor.replace(b'=', b'%3D') # If both pytorch and safetensors are available, download safetensors only # Also if GGUF and safetensors are available, download only safetensors if (has_pytorch or has_pt or has_gguf) and has_safetensors: has_gguf = False for i in range(len(classifications) - 1, -1, -1): if classifications[i] in ['pytorch', 'pt', 'gguf']: links.pop(i) file_sizes.pop(i) # For GGUF, try to download only the Q4_K_M if no specific file is specified. if has_gguf and specific_file is None: has_q4km = False for i in range(len(classifications) - 1, -1, -1): if 'q4_k_m' in links[i].lower(): has_q4km = True if has_q4km: for i in range(len(classifications) - 1, -1, -1): if 'q4_k_m' not in links[i].lower(): links.pop(i) file_sizes.pop(i) else: for i in range(len(classifications) - 1, -1, -1): if links[i].lower().endswith('.gguf'): links.pop(i) file_sizes.pop(i) is_llamacpp = has_gguf and specific_file is not None return links, sha256, is_lora, is_llamacpp, file_sizes def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None, user_data_dir=None): if model_dir: base_folder = model_dir else: if user_data_dir is None: user_data_dir = resolve_user_data_dir() base_folder = str(user_data_dir / 'models') if not is_lora else str(user_data_dir / 'loras') # If the model is of type GGUF, save directly in the base_folder if is_llamacpp: return Path(base_folder) output_folder = f"{'_'.join(model.split('/')[-2:])}" if branch != 'main': output_folder += f'_{branch}' output_folder = Path(base_folder) / output_folder return output_folder @property def progress_bar_slots(self): if self._progress_bar_slots is None: raise RuntimeError("Progress bar slots not 
initialized. Start download threads first.") return self._progress_bar_slots def initialize_progress_bar_slots(self, num_threads): self._progress_bar_slots = Array("B", [0] * num_threads) def get_progress_bar_position(self): with self.progress_bar_slots.get_lock(): for i in range(len(self.progress_bar_slots)): if self.progress_bar_slots[i] == 0: self.progress_bar_slots[i] = 1 return i return 0 # fallback def release_progress_bar_position(self, slot): with self.progress_bar_slots.get_lock(): self.progress_bar_slots[slot] = 0 def get_single_file(self, url, output_folder, start_from_scratch=False): filename = Path(url.rsplit('/', 1)[1]) output_path = output_folder / filename progress_bar_position = self.get_progress_bar_position() max_retries = self.max_retries attempt = 0 file_downloaded_count_for_progress = 0 try: while attempt < max_retries: attempt += 1 session = self.session headers = {} mode = 'wb' current_file_size_on_disk = 0 try: if output_path.exists() and not start_from_scratch: current_file_size_on_disk = output_path.stat().st_size # Make a HEAD request without following redirects to get metadata first r_head = session.head(url, timeout=20, allow_redirects=True) r_head.raise_for_status() # Will raise an error for 4xx or 5xx status codes # Check for the new 'x-linked-size' header from Hugging Face if 'x-linked-size' in r_head.headers: total_size = int(r_head.headers['x-linked-size']) # Fallback to the old 'content-length' just in case elif 'content-length' in r_head.headers: total_size = int(r_head.headers.get('content-length', 0)) else: total_size = 0 if current_file_size_on_disk >= total_size and total_size > 0: if self.progress_queue is not None and total_size > 0: self.progress_queue.put((1.0, str(filename))) return headers = {'Range': f'bytes={current_file_size_on_disk}-'} mode = 'ab' with session.get(url, stream=True, headers=headers, timeout=30) as r: r.raise_for_status() total_size_from_stream = int(r.headers.get('content-length', 0)) if mode == 
'ab': effective_total_size = current_file_size_on_disk + total_size_from_stream else: effective_total_size = total_size_from_stream block_size = 1024 * 1024 filename_str = str(filename) tqdm_kwargs = { 'total': effective_total_size, 'initial': current_file_size_on_disk if mode == 'ab' else 0, 'unit': 'B', 'unit_scale': True, 'unit_divisor': 1024, 'bar_format': '{desc}{percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', 'desc': f"{filename_str}: ", 'position': progress_bar_position, 'leave': False } if 'COLAB_GPU' in os.environ: tqdm_kwargs.update({ 'position': 0, 'leave': True }) with open(output_path, mode) as f: if mode == 'ab': f.seek(current_file_size_on_disk) with tqdm.tqdm(**tqdm_kwargs) as t: file_downloaded_count_for_progress = current_file_size_on_disk for data in r.iter_content(block_size): f.write(data) t.update(len(data)) if effective_total_size != 0 and self.progress_queue is not None: file_downloaded_count_for_progress += len(data) progress_fraction = float(file_downloaded_count_for_progress) / float(effective_total_size) self.progress_queue.put((progress_fraction, filename_str)) break except (RequestException, ConnectionError, Timeout) as e: print(f"Error downloading {filename}: {e}.") print(f"That was attempt {attempt}/{max_retries}.", end=' ') if attempt < max_retries: print(f"Retry begins in {2 ** attempt} seconds.") sleep(2 ** attempt) else: print("Failed to download after the maximum number of attempts.") finally: self.release_progress_bar_position(progress_bar_position) def start_download_threads(self, file_list, output_folder, start_from_scratch=False, threads=4): self.initialize_progress_bar_slots(threads) tqdm.tqdm.set_lock(tqdm.tqdm.get_lock()) try: thread_map( lambda url: self.get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True ) finally: print(f"\nDownload of {len(file_list)} files to {output_folder} completed.") def 
download_model_files(self, model, branch, links, sha256, output_folder, progress_queue=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False): self.progress_queue = progress_queue output_folder.mkdir(parents=True, exist_ok=True) if not is_llamacpp: metadata = f'url: https://huggingface.co/{model}\n' \ f'branch: {branch}\n' \ f'download date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n' sha256_str = '\n'.join([f' {item[1]} {item[0]}' for item in sha256]) if sha256_str: metadata += f'sha256sum:\n{sha256_str}' metadata += '\n' (output_folder / 'huggingface-metadata.txt').write_text(metadata) if specific_file: print(f"Downloading {specific_file} to {output_folder}") else: print(f"Downloading the model to {output_folder}") self.start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) def check_model_files(self, model, branch, links, sha256, output_folder): # Validate the checksums validated = True for i in range(len(sha256)): fpath = (output_folder / sha256[i][0]) if not fpath.exists(): print(f"The following file is missing: {fpath}") validated = False continue with open(output_folder / sha256[i][0], "rb") as f: bytes = f.read() file_hash = hashlib.sha256(bytes).hexdigest() if file_hash != sha256[i][1]: print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}') validated = False else: print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') if validated: print('[+] Validated checksums of all model files!') else: print('[-] Invalid checksums. 
Rerun download-model.py with the --clean flag.') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.') parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).') parser.add_argument('--exclude-pattern', type=str, default=None, help='Regex pattern to exclude files from download.') parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.') parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (user_data/models).') parser.add_argument('--user-data-dir', type=str, default=None, help='Path to the user data directory. Overrides auto-detection.') parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--max-retries', type=int, default=7, help='Max retries count when get error in download time.') args = parser.parse_args() branch = args.branch model = args.MODEL specific_file = args.specific_file exclude_pattern = args.exclude_pattern if model is None: print("Error: Please specify the model you'd like to download (e.g. 'python download-model.py facebook/opt-1.3b').") sys.exit() downloader = ModelDownloader(max_retries=args.max_retries) # Handle direct file URLs (e.g. 
https://huggingface.co/org/repo/resolve/branch/file.gguf) if '/resolve/' in model: url = model if model.startswith('http') else f'{base}/{model}' url = url.split('?')[0] filename = url.split('/')[-1] if args.output: output_folder = Path(args.output) elif args.model_dir: output_folder = Path(args.model_dir) else: user_data_dir = Path(args.user_data_dir) if args.user_data_dir else resolve_user_data_dir() output_folder = user_data_dir / 'models' output_folder.mkdir(parents=True, exist_ok=True) print(f"Downloading {filename} to {output_folder}") downloader.get_single_file(url, output_folder, start_from_scratch=args.clean) sys.exit() # Clean up the model/branch names try: model, branch = downloader.sanitize_model_and_branch_names(model, branch) except ValueError as err_branch: print(f"Error: {err_branch}") sys.exit() # Get the download links from Hugging Face links, sha256, is_lora, is_llamacpp, file_sizes = downloader.get_download_links_from_huggingface( model, branch, text_only=args.text_only, specific_file=specific_file, exclude_pattern=exclude_pattern ) # Get the output folder user_data_dir = Path(args.user_data_dir) if args.user_data_dir else None if args.output: output_folder = Path(args.output) else: output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir, user_data_dir=user_data_dir) if args.check: # Check previously downloaded files downloader.check_model_files(model, branch, links, sha256, output_folder) else: # Download files downloader.download_model_files( model, branch, links, sha256, output_folder, specific_file=specific_file, threads=args.threads, is_llamacpp=is_llamacpp ) ================================================ FILE: js/dark_theme.js ================================================ function toggleDarkMode() { document.body.classList.toggle("dark"); var currentCSS = document.getElementById("highlight-css"); if (currentCSS.getAttribute("href") === 
"file/css/highlightjs/github-dark.min.css") { currentCSS.setAttribute("href", "file/css/highlightjs/github.min.css"); } else { currentCSS.setAttribute("href", "file/css/highlightjs/github-dark.min.css"); } // Re-highlight all code blocks once stylesheet loads currentCSS.onload = function() { const messageBodies = document.getElementById("chat").querySelectorAll(".message-body"); messageBodies.forEach((messageBody) => { const codeBlocks = messageBody.querySelectorAll("pre code"); codeBlocks.forEach((codeBlock) => { hljs.highlightElement(codeBlock); }); }); }; } ================================================ FILE: js/global_scope_js.js ================================================ // ------------------------------------------------- // Event handlers // ------------------------------------------------- function copyToClipboard(element) { if (!element) return; const messageElement = element.closest(".message, .user-message, .assistant-message"); if (!messageElement) return; const rawText = messageElement.getAttribute("data-raw"); if (!rawText) return; const copyPromise = navigator.clipboard && window.isSecureContext ? navigator.clipboard.writeText(rawText) : fallbackCopyToClipboard(rawText); copyPromise.then(function() { const originalSvg = element.innerHTML; element.innerHTML = ""; setTimeout(() => { element.innerHTML = originalSvg; }, 1000); }).catch(function(err) { console.error("Failed to copy text: ", err); }); } function fallbackCopyToClipboard(text) { return new Promise((resolve, reject) => { const textArea = document.createElement("textarea"); textArea.value = text; textArea.style.position = "fixed"; textArea.style.left = "-9999px"; textArea.style.top = "-9999px"; document.body.appendChild(textArea); textArea.focus(); textArea.select(); try { const successful = document.execCommand("copy"); document.body.removeChild(textArea); successful ? 
resolve() : reject(); } catch (err) { document.body.removeChild(textArea); reject(err); } }); } function branchHere(element) { if (!element) return; const messageElement = element.closest(".message, .user-message, .assistant-message"); if (!messageElement) return; const index = messageElement.getAttribute("data-index"); if (!index) return; const branchIndexInput = document.getElementById("Branch-index").querySelector("input"); if (!branchIndexInput) { console.error("Element with ID 'Branch-index' not found."); return; } const branchButton = document.getElementById("Branch"); if (!branchButton) { console.error("Required element 'Branch' not found."); return; } branchIndexInput.value = index; // Trigger any 'change' or 'input' events Gradio might be listening for const event = new Event("input", { bubbles: true }); branchIndexInput.dispatchEvent(event); branchButton.click(); } // ------------------------------------------------- // Message Editing Functions // ------------------------------------------------- function editHere(buttonElement) { if (!buttonElement) return; const messageElement = buttonElement.closest(".message, .user-message, .assistant-message"); if (!messageElement) return; const messageBody = messageElement.querySelector(".message-body"); if (!messageBody) return; // If already editing, focus the textarea const existingTextarea = messageBody.querySelector(".editing-textarea"); if (existingTextarea) { existingTextarea.focus(); return; } // Determine role based on message element - handle different chat modes const isUserMessage = messageElement.classList.contains("user-message") || messageElement.querySelector(".text-you") !== null || messageElement.querySelector(".circle-you") !== null; startEditing(messageElement, messageBody, isUserMessage); } function startEditing(messageElement, messageBody, isUserMessage) { const rawText = messageElement.getAttribute("data-raw") || messageBody.textContent; const originalHTML = messageBody.innerHTML; // Create 
editing interface const editingInterface = createEditingInterface(rawText); // Replace message content messageBody.innerHTML = ""; messageBody.appendChild(editingInterface.textarea); messageBody.appendChild(editingInterface.controls); editingInterface.textarea.focus(); editingInterface.textarea.setSelectionRange(rawText.length, rawText.length); // Temporarily mark as scrolled to prevent auto-scroll const wasScrolled = window.isScrolled; window.isScrolled = true; // Scroll the textarea into view editingInterface.textarea.scrollIntoView({ behavior: "smooth", block: "center" }); // Restore the original scroll state after animation setTimeout(() => { window.isScrolled = wasScrolled; }, 500); // Setup event handlers setupEditingHandlers(editingInterface.textarea, messageElement, originalHTML, messageBody, isUserMessage); } function createEditingInterface(text) { const textarea = document.createElement("textarea"); textarea.value = text; textarea.className = "editing-textarea"; textarea.rows = Math.max(3, text.split("\n").length); const controls = document.createElement("div"); controls.className = "edit-controls-container"; const saveButton = document.createElement("button"); saveButton.textContent = "Save"; saveButton.className = "edit-control-button"; saveButton.type = "button"; const cancelButton = document.createElement("button"); cancelButton.textContent = "Cancel"; cancelButton.className = "edit-control-button edit-cancel-button"; cancelButton.type = "button"; controls.appendChild(saveButton); controls.appendChild(cancelButton); return { textarea, controls, saveButton, cancelButton }; } function setupEditingHandlers(textarea, messageElement, originalHTML, messageBody, isUserMessage) { const saveButton = messageBody.querySelector(".edit-control-button:not(.edit-cancel-button)"); const cancelButton = messageBody.querySelector(".edit-cancel-button"); const submitEdit = () => { const index = messageElement.getAttribute("data-index"); if (!index || 
!submitMessageEdit(index, textarea.value, isUserMessage)) { cancelEdit(); } }; const cancelEdit = () => { messageBody.innerHTML = originalHTML; }; // Event handlers saveButton.onclick = submitEdit; cancelButton.onclick = cancelEdit; textarea.onkeydown = (e) => { if (e.key === "Enter" && !e.shiftKey) { e.preventDefault(); submitEdit(); } else if (e.key === "Escape") { e.preventDefault(); cancelEdit(); } }; } function submitMessageEdit(index, newText, isUserMessage) { const editIndexInput = document.getElementById("Edit-message-index")?.querySelector("input"); const editTextInput = document.getElementById("Edit-message-text")?.querySelector("textarea"); const editRoleInput = document.getElementById("Edit-message-role")?.querySelector("textarea"); const editButton = document.getElementById("Edit-message"); if (!editIndexInput || !editTextInput || !editRoleInput || !editButton) { console.error("Edit elements not found"); return false; } editIndexInput.value = index; editTextInput.value = newText; editRoleInput.value = isUserMessage ? 
"user" : "assistant"; editIndexInput.dispatchEvent(new Event("input", { bubbles: true })); editTextInput.dispatchEvent(new Event("input", { bubbles: true })); editRoleInput.dispatchEvent(new Event("input", { bubbles: true })); editButton.click(); return true; } function navigateVersion(element, direction) { if (!element) return; const messageElement = element.closest(".message, .user-message, .assistant-message"); if (!messageElement) return; const index = messageElement.getAttribute("data-index"); if (!index) return; // Determine role based on message element classes let role = "assistant"; // Default role if (messageElement.classList.contains("user-message") || messageElement.querySelector(".text-you") || messageElement.querySelector(".circle-you")) { role = "user"; } const indexInput = document.getElementById("Navigate-message-index")?.querySelector("input"); const directionInput = document.getElementById("Navigate-direction")?.querySelector("textarea"); const roleInput = document.getElementById("Navigate-message-role")?.querySelector("textarea"); const navigateButton = document.getElementById("Navigate-version"); if (!indexInput || !directionInput || !roleInput || !navigateButton) { console.error("Navigation control elements (index, direction, role, or button) not found."); return; } indexInput.value = index; directionInput.value = direction; roleInput.value = role; // Trigger 'input' events for Gradio to pick up changes const event = new Event("input", { bubbles: true }); indexInput.dispatchEvent(event); directionInput.dispatchEvent(event); roleInput.dispatchEvent(event); navigateButton.click(); } function regenerateClick() { document.getElementById("Regenerate").click(); } function continueClick() { document.getElementById("Continue").click(); } function removeLastClick() { document.getElementById("Remove-last").click(); } function autoScrollToBottom() { if (!window.isScrolled) { const chatParent = 
document.getElementById("chat")?.parentNode?.parentNode?.parentNode; if (chatParent) { const maxScroll = chatParent.scrollHeight - chatParent.clientHeight; if (maxScroll > 0 && chatParent.scrollTop < maxScroll - 1) { chatParent.scrollTop = maxScroll; } } } } function updateInstructPadding() { const chatElement = document.getElementById("chat"); if (chatElement && chatElement.getAttribute("data-mode") === "instruct") { const messagesContainer = chatElement.querySelector(".messages"); const lastChild = messagesContainer?.lastElementChild; const prevSibling = lastChild?.previousElementSibling; if (lastChild && prevSibling && chatElement.offsetHeight > 0) { let bufferHeight = Math.max(0, Math.max(window.innerHeight - 128 - 84, window.innerHeight - prevSibling.offsetHeight - 84) - lastChild.offsetHeight); if (window.innerWidth <= 924) { bufferHeight = Math.max(0, bufferHeight - 32); } messagesContainer.style.paddingBottom = `${bufferHeight}px`; } } } let pendingMorphdomData = null; let morphdomRafId = null; function handleMorphdomUpdate(data) { pendingMorphdomData = data; if (!morphdomRafId) { morphdomRafId = requestAnimationFrame(() => { morphdomRafId = null; applyMorphdomUpdate(pendingMorphdomData); pendingMorphdomData = null; }); } } function applyMorphdomUpdate(data) { // Determine target element and use it as query scope var target_element, target_html; if (data.last_message_only) { const childNodes = document.getElementsByClassName("messages")[0].childNodes; target_element = childNodes[childNodes.length - 1]; target_html = data.html; } else { target_element = document.getElementById("chat").parentNode; target_html = "
" + data.html + "
"; } const queryScope = target_element; // Track open blocks and store their scroll positions const openBlocks = new Set(); const scrollPositions = {}; queryScope.querySelectorAll(".thinking-block").forEach(block => { const blockId = block.getAttribute("data-block-id"); if (blockId && block.hasAttribute("open")) { openBlocks.add(blockId); const content = block.querySelector(".thinking-content"); if (content) { const isAtBottom = Math.abs((content.scrollHeight - content.scrollTop) - content.clientHeight) < 5; scrollPositions[blockId] = { position: content.scrollTop, isAtBottom: isAtBottom }; } } }); morphdom( target_element, target_html, { onBeforeElUpdated: function(fromEl, toEl) { // Preserve code highlighting if (fromEl.tagName === "PRE") { const fromCode = fromEl.querySelector("code[data-highlighted]"); const toCode = toEl.querySelector("code"); if (fromCode && toCode && fromCode.textContent === toCode.textContent) { toEl.className = fromEl.className; toEl.innerHTML = fromEl.innerHTML; return false; } } // For thinking blocks, assume closed by default if (fromEl.classList && fromEl.classList.contains("thinking-block") && toEl.classList && toEl.classList.contains("thinking-block")) { const blockId = toEl.getAttribute("data-block-id"); // Remove open attribute by default toEl.removeAttribute("open"); // If this block was explicitly opened by user, keep it open if (blockId && openBlocks.has(blockId)) { toEl.setAttribute("open", ""); } } return !fromEl.isEqualNode(toEl); }, onElUpdated: function(el) { // Restore scroll positions for open thinking blocks if (el.classList && el.classList.contains("thinking-block") && el.hasAttribute("open")) { const blockId = el.getAttribute("data-block-id"); const content = el.querySelector(".thinking-content"); if (content && blockId && scrollPositions[blockId]) { setTimeout(() => { if (scrollPositions[blockId].isAtBottom) { content.scrollTop = content.scrollHeight; } else { content.scrollTop = scrollPositions[blockId].position; } 
}, 0); } } } } ); // Syntax highlighting and LaTeX if (window.doSyntaxHighlighting) { window.doSyntaxHighlighting(); } // Auto-scroll runs both before and after padding update. // Before: so content growth isn't hidden by padding absorption. // After: so padding-added space is also scrolled into view. autoScrollToBottom(); updateInstructPadding(); autoScrollToBottom(); // Add toggle listeners for new blocks queryScope.querySelectorAll(".thinking-block").forEach(block => { if (!block._hasToggleListener) { block.addEventListener("toggle", function(e) { const wasScrolled = window.isScrolled; if (this.open) { const content = this.querySelector(".thinking-content"); if (content) { setTimeout(() => { content.scrollTop = content.scrollHeight; }, 0); } } autoScrollToBottom(); updateInstructPadding(); autoScrollToBottom(); // Restore scroll state so the browser's layout adjustment // from the toggle doesn't disable auto-scroll window.isScrolled = wasScrolled; }); block._hasToggleListener = true; } }); } ================================================ FILE: js/katex/auto-render.js ================================================ ! function(e, t) { "object" == typeof exports && "object" == typeof module ? module.exports = t(require("katex")) : "function" == typeof define && define.amd ? define(["katex"], t) : "object" == typeof exports ? exports.renderMathInElement = t(require("katex")) : e.renderMathInElement = t(e.katex) }("undefined" != typeof self ? self : this, (function(e) { return function() { "use strict"; var t = { 771: function(t) { t.exports = e } }, n = {}; function r(e) { var o = n[e]; if (void 0 !== o) return o.exports; var i = n[e] = { exports: {} }; return t[e](i, i.exports, r), i.exports } r.n = function(e) { var t = e && e.__esModule ? 
function() { return e.default } : function() { return e }; return r.d(t, { a: t }), t }, r.d = function(e, t) { for (var n in t) r.o(t, n) && !r.o(e, n) && Object.defineProperty(e, n, { enumerable: !0, get: t[n] }) }, r.o = function(e, t) { return Object.prototype.hasOwnProperty.call(e, t) }; var o = {}; return function() { r.d(o, { default: function() { return d } }); var e = r(771), t = r.n(e); const n = function(e, t, n) { let r = n, o = 0; const i = e.length; for (; r < t.length;) { const n = t[r]; if (o <= 0 && t.slice(r, r + i) === e) return r; "\\" === n ? r++ : "{" === n ? o++ : "}" === n && o--, r++ } return -1 }, i = /^\\begin{/; var a = function(e, t) { let r; const o = [], a = new RegExp("(" + t.map((e => e.left.replace(/[-/\\^$*+?.()|[\]{}]/g, "\\$&"))).join("|") + ")"); for (; r = e.search(a), -1 !== r;) { const charAfterOpen = e[r + 1]; if (e[r] == "$" && charAfterOpen != "$") { const closeDollarIndex = e.indexOf('$', r + 1); if (closeDollarIndex != -1) { const charBeforeOpen = r > 0 ? e[r - 1] : ''; const charBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 1] : ''; const charBeforeBeforeClose = r + 1 < closeDollarIndex ? e[closeDollarIndex - 2] : ''; const charAfterClose = closeDollarIndex + 1 < e.length ? e[closeDollarIndex + 1] : ''; if ((/[A-Za-z0-9_$-]/.test(charBeforeOpen)) || ((' ' == charBeforeClose) || /[0-9]/.test(charAfterOpen) && (/[A-Za-z0-9]/.test(charAfterClose) || '-' == charBeforeClose))) { o.push({ type: "text", data: e.slice(0, r + 1), }); e = e.slice(r + 1); // now text starts after delimiter continue; } } } r > 0 && (o.push({ type: "text", data: e.slice(0, r) }), e = e.slice(r)); const a = t.findIndex((t => e.startsWith(t.left))); if (r = n(t[a].right, e, t[a].left.length), -1 === r) break; const l = e.slice(0, r + t[a].right.length), s = i.test(l) ? 
l : e.slice(t[a].left.length, r); o.push({ type: "math", data: s, rawData: l, display: t[a].display }), e = e.slice(r + t[a].right.length) } return "" !== e && o.push({ type: "text", data: e }), o }; const l = function(e, n) { const r = a(e, n.delimiters); if (1 === r.length && "text" === r[0].type) return null; const o = document.createDocumentFragment(); for (let e = 0; e < r.length; e++) if ("text" === r[e].type) o.appendChild(document.createTextNode(r[e].data)); else { const i = document.createElement("span"); let a = r[e].data; n.displayMode = r[e].display; try { n.preProcess && (a = n.preProcess(a)), t().render(a, i, n) } catch (i) { if (!(i instanceof t().ParseError)) throw i; n.errorCallback("KaTeX auto-render: Failed to parse `" + r[e].data + "` with ", i), o.appendChild(document.createTextNode(r[e].rawData)); continue } o.appendChild(i) } return o }, s = function(e, t) { for (let n = 0; n < e.childNodes.length; n++) { const r = e.childNodes[n]; if (3 === r.nodeType) { let o = r.textContent, i = r.nextSibling, a = 0; for (; i && i.nodeType === Node.TEXT_NODE;) o += i.textContent, i = i.nextSibling, a++; const s = l(o, t); if (s) { for (let e = 0; e < a; e++) r.nextSibling.remove(); n += s.childNodes.length - 1, e.replaceChild(s, r) } else n += a } else if (1 === r.nodeType) { const e = " " + r.className + " "; - 1 === t.ignoredTags.indexOf(r.nodeName.toLowerCase()) && t.ignoredClasses.every((t => -1 === e.indexOf(" " + t + " "))) && s(r, t) } } }; var d = function(e, t) { if (!e) throw new Error("No element provided to render"); const n = {}; for (const e in t) t.hasOwnProperty(e) && (n[e] = t[e]); n.delimiters = n.delimiters || [{ left: "$$", right: "$$", display: !0 }, { left: "\\(", right: "\\)", display: !1 }, { left: "\\begin{equation}", right: "\\end{equation}", display: !0 }, { left: "\\begin{align}", right: "\\end{align}", display: !0 }, { left: "\\begin{alignat}", right: "\\end{alignat}", display: !0 }, { left: "\\begin{gather}", right: 
"\\end{gather}", display: !0 }, { left: "\\begin{CD}", right: "\\end{CD}", display: !0 }, { left: "\\[", right: "\\]", display: !0 }], n.ignoredTags = n.ignoredTags || ["script", "noscript", "style", "textarea", "pre", "code", "option"], n.ignoredClasses = n.ignoredClasses || [], n.errorCallback = n.errorCallback || console.error, n.macros = n.macros || {}, s(e, n) } }(), o = o.default }() })); ================================================ FILE: js/main.js ================================================ // ------------------------------------------------ // Main // ------------------------------------------------ // Sync highlight.js theme with the actual Gradio theme var defined_hljs_css = document.body.classList.contains("dark") ? "file/css/highlightjs/github-dark.min.css" : "file/css/highlightjs/github.min.css"; if (document.getElementById("highlight-css").getAttribute("href") !== defined_hljs_css) { document.getElementById("highlight-css").setAttribute("href", defined_hljs_css); } let main_parent = document.getElementById("chat-tab").parentNode; let extensions = document.getElementById("extensions"); main_parent.childNodes[0].classList.add("header_bar"); main_parent.style = "padding: 0; margin: 0"; main_parent.parentNode.style = "gap: 0"; main_parent.parentNode.parentNode.style = "padding: 0"; document.querySelector(".header_bar").addEventListener("click", function(event) { if (event.target.tagName !== "BUTTON") return; const buttonText = event.target.textContent.trim(); const extensionsVisible = ["Chat", "Default", "Notebook"].includes(buttonText); const chatVisible = buttonText === "Chat"; const showControlsChecked = document.querySelector("#show-controls input").checked; const extensions = document.querySelector("#extensions"); if (extensionsVisible) { if (extensions) { extensions.style.display = "flex"; } this.style.marginBottom = chatVisible ? 
"0px" : "19px"; if (chatVisible && !showControlsChecked) { document.querySelectorAll("#extensions").forEach(element => { element.style.display = "none"; }); } } else { this.style.marginBottom = "19px"; if (extensions) extensions.style.display = "none"; } }); //------------------------------------------------ // Keyboard shortcuts //------------------------------------------------ // --- Helper functions --- // function isModifiedKeyboardEvent() { return (event instanceof KeyboardEvent && event.shiftKey || event.ctrlKey || event.altKey || event.metaKey); } function isFocusedOnEditableTextbox() { if (event.target.tagName === "INPUT" || event.target.tagName === "TEXTAREA") { return !!event.target.value; } } let previousTabId = "chat-tab-button"; document.addEventListener("keydown", function(event) { // Stop generation on Esc pressed if (event.key === "Escape") { // Find the element with id 'stop' and click it var stopButton = document.getElementById("stop"); if (stopButton) { stopButton.click(); } return; } if (!document.querySelector("#chat-tab").checkVisibility() ) { return; } // Show chat controls on Ctrl + S if (event.ctrlKey && event.key == "s") { event.preventDefault(); var showControlsElement = document.getElementById("show-controls"); if (showControlsElement && showControlsElement.childNodes.length >= 4) { showControlsElement.childNodes[3].click(); var arr = document.getElementById("chat-input").childNodes[2].childNodes; arr[arr.length - 1].focus(); } } // Regenerate on Ctrl + Enter else if (event.ctrlKey && event.key === "Enter") { event.preventDefault(); document.getElementById("Regenerate").click(); } // Continue on Alt + Enter else if (event.altKey && event.key === "Enter") { event.preventDefault(); document.getElementById("Continue").click(); } // Remove last on Ctrl + Shift + Backspace else if (event.ctrlKey && event.shiftKey && event.key === "Backspace") { event.preventDefault(); document.getElementById("Remove-last").click(); } // Impersonate on Ctrl + 
Shift + M else if (event.ctrlKey && event.shiftKey && event.key === "M") { event.preventDefault(); document.getElementById("Impersonate").click(); } // --- Simple version navigation --- // if (!isFocusedOnEditableTextbox()) { // Version navigation on Arrow keys (horizontal) if (!isModifiedKeyboardEvent() && event.key === "ArrowLeft") { event.preventDefault(); navigateLastAssistantMessage("left"); } else if (!isModifiedKeyboardEvent() && event.key === "ArrowRight") { event.preventDefault(); if (!navigateLastAssistantMessage("right")) { // If can't navigate right (last version), regenerate document.getElementById("Regenerate").click(); } } } }); //------------------------------------------------ // Position the chat typing dots //------------------------------------------------ typing = document.getElementById("typing-container"); typingParent = typing.parentNode; typingSibling = typing.previousElementSibling; typingSibling.insertBefore(typing, typingSibling.childNodes[2]); //------------------------------------------------ // Chat scrolling //------------------------------------------------ const targetElement = document.getElementById("chat").parentNode.parentNode.parentNode; targetElement.classList.add("pretty_scrollbar"); targetElement.classList.add("chat-parent"); window.isScrolled = false; let scrollTimeout; let lastScrollTop = 0; let lastScrollHeight = 0; let lastClientHeight = 0; targetElement.addEventListener("scroll", function() { let diff = targetElement.scrollHeight - targetElement.clientHeight; let isAtBottomNow = Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0; // Add scrolling class to disable hover effects if (window.isScrolled || !isAtBottomNow) { targetElement.classList.add("scrolling"); } if(isAtBottomNow) { window.isScrolled = false; } else if (targetElement.scrollTop < lastScrollTop && targetElement.scrollHeight >= lastScrollHeight && targetElement.clientHeight <= lastClientHeight) { window.isScrolled = true; } lastScrollTop = 
targetElement.scrollTop; lastScrollHeight = targetElement.scrollHeight; lastClientHeight = targetElement.clientHeight; // Clear previous timeout and set new one clearTimeout(scrollTimeout); scrollTimeout = setTimeout(() => { targetElement.classList.remove("scrolling"); doSyntaxHighlighting(); // Only run after scrolling stops }, 150); }); // Create a MutationObserver instance const observer = new MutationObserver(function() { if (targetElement.classList.contains("_generating")) { typing.parentNode.classList.add("visible-dots"); document.getElementById("stop").style.display = "flex"; document.getElementById("Generate").style.display = "none"; // If the user is near the bottom, ensure auto-scroll is enabled // for the new reply. This catches cases where isScrolled was // incorrectly set to true by layout shifts during page load, etc. const diff = targetElement.scrollHeight - targetElement.clientHeight; if (Math.abs(targetElement.scrollTop - diff) <= 10 || diff <= 0) { window.isScrolled = false; } } else { typing.parentNode.classList.remove("visible-dots"); document.getElementById("stop").style.display = "none"; document.getElementById("Generate").style.display = "flex"; } }); // Only watch for attribute changes on targetElement (e.g. 
_generating class) const config = { attributes: true }; // Start observing the target element observer.observe(targetElement, config); //------------------------------------------------ // Handle syntax highlighting / LaTeX //------------------------------------------------ function isElementVisibleOnScreen(element) { const rect = element.getBoundingClientRect(); return ( rect.left < window.innerWidth && rect.right > 0 && rect.top < window.innerHeight && rect.bottom > 0 ); } window.doSyntaxHighlighting = function() { const messageBodies = document.getElementById("chat").querySelectorAll(".message-body"); if (messageBodies.length > 0) { let hasSeenVisible = false; // Go from last message to first for (let i = messageBodies.length - 1; i >= 0; i--) { const messageBody = messageBodies[i]; if (isElementVisibleOnScreen(messageBody)) { hasSeenVisible = true; // Handle both code and math in a single pass through each message const codeBlocks = messageBody.querySelectorAll("pre code:not([data-highlighted])"); codeBlocks.forEach((codeBlock) => { hljs.highlightElement(codeBlock); codeBlock.setAttribute("data-highlighted", "true"); codeBlock.classList.add("pretty_scrollbar"); }); // Only render math in visible elements const mathContainers = messageBody.querySelectorAll("p, span, li, td, th, h1, h2, h3, h4, h5, h6, blockquote, figcaption, caption, dd, dt"); mathContainers.forEach(container => { if (isElementVisibleOnScreen(container)) { renderMathInElement(container, { delimiters: [ { left: "$$", right: "$$", display: true }, { left: "$", right: "$", display: false }, { left: "\\(", right: "\\)", display: false }, { left: "\\[", right: "\\]", display: true }, ], }); } }); } else if (hasSeenVisible) { // We've seen visible messages but this one is not visible // Since we're going from last to first, we can break break; } } } } const doSyntaxHighlighting = window.doSyntaxHighlighting; //------------------------------------------------ // Add some scrollbars 
//------------------------------------------------
// Loop variables below are block-scoped so no implicit global `i` is created.
const scrollbarElements = document.querySelectorAll(".add_scrollbar textarea, .add_scrollbar .drag-drop-list");
for (const scrollbarElement of scrollbarElements) {
  scrollbarElement.classList.remove("scroll-hide");
  scrollbarElement.classList.add("pretty_scrollbar");
  scrollbarElement.style.resize = "none";
}

//------------------------------------------------
// Tools: inject "Refresh list" link into the label
//------------------------------------------------
const toolsTitle = document.querySelector("#tools-group > [data-testid='block-info']");
const toolsInfo = toolsTitle ? toolsTitle.nextElementSibling : null;
if (toolsInfo) {
  const refreshLink = document.createElement("span");
  refreshLink.textContent = " [Refresh list]";
  refreshLink.className = "tools-refresh-link";
  refreshLink.addEventListener("click", function(e) {
    e.preventDefault();
    // Forward the click to the #tools-refresh-btn element
    document.querySelector("#tools-refresh-btn").click();
  });
  toolsInfo.appendChild(refreshLink);
}

//------------------------------------------------
// Remove some backgrounds
//------------------------------------------------
const noBackgroundelements = document.querySelectorAll(".no-background");
for (const element of noBackgroundelements) {
  element.parentNode.style.border = "none";
  element.parentNode.parentNode.parentNode.style.alignItems = "center";
}

const slimDropdownElements = document.querySelectorAll(".slim-dropdown");
for (const element of slimDropdownElements) {
  const parentNode = element.parentNode;
  parentNode.style.background = "transparent";
  parentNode.style.border = "0";
}

//------------------------------------------------
// Create the hover menu in the chat tab
// The show/hide events were adapted from:
// https://github.com/SillyTavern/SillyTavern/blob/6c8bd06308c69d51e2eb174541792a870a83d2d6/public/script.js
//------------------------------------------------
var buttonsInChat = document.querySelectorAll("#chat-tab #chat-buttons button, #chat-tab #chat-buttons #show-controls");
var button = document.getElementById("hover-element-button");
var menu = document.getElementById("hover-menu");
var istouchscreen = (navigator.maxTouchPoints > 0) || "ontouchstart" in document.documentElement;

function showMenu() {
  menu.style.display = "flex"; // Show the menu
}

function hideMenu() {
  menu.style.display = "none"; // Hide the menu
  if (!istouchscreen) {
    document.querySelector("#chat-input textarea").focus(); // Focus on the chat input
  }
}

if (buttonsInChat.length > 0) {
  // Iterate backwards, so the buttons are appended to the menu in reverse order
  for (let i = buttonsInChat.length - 1; i >= 0; i--) {
    const thisButton = buttonsInChat[i];
    menu.appendChild(thisButton);

    // Only apply transformations to button elements
    if (thisButton.tagName.toLowerCase() === "button") {
      thisButton.addEventListener("click", () => {
        hideMenu();
      });

      const buttonText = thisButton.textContent;
      const matches = buttonText.match(/(\(.*?\))/);

      if (matches && matches.length > 1) {
        // Apply the transparent-substring class to the matched substring
        // NOTE(review): the template below contains no markup in this copy,
        // so no class is actually applied here; confirm against upstream.
        const substring = matches[1];
        const newText = buttonText.replace(substring, ` ${substring.slice(1, -1)}`);
        thisButton.innerHTML = newText;
      }
    }
  }
}

function isMouseOverButtonOrMenu() {
  return menu.matches(":hover") || button.matches(":hover");
}

button.addEventListener("mouseenter", function () {
  if (!istouchscreen) {
    showMenu();
  }
});

button.addEventListener("click", function () {
  if (menu.style.display === "flex") {
    hideMenu();
  }
  else {
    showMenu();
  }
});

// Add event listener for mouseleave on the button
button.addEventListener("mouseleave", function () {
  // Delay to prevent menu hiding when the mouse leaves the button into the menu
  setTimeout(function () {
    if (!isMouseOverButtonOrMenu()) {
      hideMenu();
    }
  }, 100);
});

// Add event listener for mouseleave on the menu
menu.addEventListener("mouseleave", function () {
  // Delay to prevent menu hide when the mouse leaves the menu into the button
  setTimeout(function () {
    if (!isMouseOverButtonOrMenu()) {
      hideMenu();
    }
  }, 100);
});

// Add event listener for click anywhere in the document
document.addEventListener("click", function (event) {
  const target = event.target;

  // Check if the click is outside the button/menu and the menu is visible
  if (!isMouseOverButtonOrMenu() && menu.style.display === "flex") {
    hideMenu();
  }

  if (event.target.classList.contains("pfp_character")) {
    toggleBigPicture();
  }

  // Handle sidebar clicks on mobile
  if (isMobile()) {
    // Check if the click did NOT originate from any of the specified toggle buttons or elements
    if (
      target.closest("#navigation-toggle") !== navigationToggle &&
      target.closest("#past-chats-toggle") !== pastChatsToggle &&
      target.closest("#chat-controls-toggle") !== chatControlsToggle &&
      target.closest(".header_bar") !== headerBar &&
      target.closest("#past-chats-row") !== pastChatsRow &&
      target.closest("#chat-controls") !== chatControlsRow
    ) {
      handleIndividualSidebarClose(event);
    }
  }
});

//------------------------------------------------
// Position the chat input
//------------------------------------------------
document.getElementById("chat-input-row").classList.add("chat-input-positioned");

//------------------------------------------------
// Focus on the chat input
//------------------------------------------------
const chatTextArea = document.getElementById("chat-input").querySelector("textarea");

/**
 * Invoke `callback(isVisible)` whenever `element`'s intersection with the
 * document root changes (isVisible is true when the ratio is above 0).
 */
function respondToChatInputVisibility(element, callback) {
  var options = {
    root: document.documentElement,
  };

  var observer = new IntersectionObserver((entries, observer) => {
    entries.forEach(entry => {
      callback(entry.intersectionRatio > 0);
    });
  }, options);

  observer.observe(element);
}

function handleChatInputVisibilityChange(isVisible) {
  if (isVisible) {
    chatTextArea.focus();
  }
}

respondToChatInputVisibility(chatTextArea, handleChatInputVisibilityChange);

//------------------------------------------------
// Show enlarged character picture when the profile
// picture is clicked on
//------------------------------------------------
let bigPictureVisible = false;

/**
 * Append an enlarged copy of the character picture to a high-level ancestor
 * of #chat. Its visibility is set to "visible" once it loads and "hidden" on
 * a load error.
 */
function addBigPicture() {
  const picture = document.createElement("img");
  // Timestamp query parameter, presumably to bypass the browser cache
  picture.src = "/file/user_data/cache/pfp_character.png?time=" + Date.now();
  picture.classList.add("bigProfilePicture");
  picture.addEventListener("load", () => {
    picture.style.visibility = "visible";
  });
  picture.addEventListener("error", () => {
    picture.style.visibility = "hidden";
  });

  // Seventh ancestor of #chat
  const host = document.getElementById("chat")
    .parentNode.parentNode.parentNode.parentNode
    .parentNode.parentNode.parentNode;
  host.appendChild(picture);
}

/** Remove every enlarged character picture currently in the document. */
function deleteBigPicture() {
  for (const picture of document.querySelectorAll(".bigProfilePicture")) {
    picture.remove();
  }
}

/** Show the enlarged picture if hidden, otherwise remove it. */
function toggleBigPicture() {
  if (!bigPictureVisible) {
    addBigPicture();
  } else {
    deleteBigPicture();
  }
  bigPictureVisible = !bigPictureVisible;
}

//------------------------------------------------
// Handle the chat input box growth
//------------------------------------------------

// Cache DOM elements
const chatContainer = document.getElementById("chat").parentNode.parentNode.parentNode;
const chatInput = document.querySelector("#chat-input textarea");

// Variables to store current dimensions
let currentChatInputHeight = chatInput.clientHeight;

//------------------------------------------------
// Focus on the rename text area when it becomes visible
//------------------------------------------------
const renameTextArea = document.getElementById("rename-row").querySelector("textarea");

/**
 * Call `callback(isVisible)` for each intersection change of `element`
 * relative to the document root.
 */
function respondToRenameVisibility(element, callback) {
  const visibilityObserver = new IntersectionObserver((entries) => {
    for (const entry of entries) {
      callback(entry.intersectionRatio > 0);
    }
  }, { root: document.documentElement });

  visibilityObserver.observe(element);
}

function handleVisibilityChange(isVisible) {
  if (!isVisible) {
    return;
  }
  renameTextArea.focus();
}

respondToRenameVisibility(renameTextArea, handleVisibilityChange);

//------------------------------------------------
// Adjust the chat tab margin if no extension UI
// is present at the bottom
//------------------------------------------------
if (!document.getElementById("extensions")) {
  document.getElementById("chat-tab").style.marginBottom = "-29px";
}

//------------------------------------------------
// Focus on the chat input after starting a new chat
//------------------------------------------------
for (const trigger of document.querySelectorAll(".focus-on-chat-input")) {
  trigger.addEventListener("click", () => {
    document.querySelector("#chat-input textarea").focus();
  });
}

//------------------------------------------------
// "New chat" hover menu with incognito option
//------------------------------------------------
(function() {
  const newChatButton = document.getElementById("new-chat-btn");

  // Wrap the button in place so an arrow with a dropdown can sit beside it
  const wrapper = document.createElement("div");
  wrapper.id = "new-chat-wrapper";
  newChatButton.replaceWith(wrapper);
  wrapper.appendChild(newChatButton);

  const incognitoItem = document.createElement("div");
  incognitoItem.className = "new-chat-menu-item";
  incognitoItem.textContent = "Incognito chat";
  incognitoItem.addEventListener("click", (e) => {
    e.stopPropagation();
    document.querySelector("#incognito-chat-btn").click();
  });

  const dropdownMenu = document.createElement("div");
  dropdownMenu.className = "new-chat-menu";
  dropdownMenu.appendChild(incognitoItem);

  // textContent is set before the menu is appended so it does not wipe it
  const dropdownArrow = document.createElement("span");
  dropdownArrow.className = "new-chat-arrow";
  dropdownArrow.textContent = "\u25BE";
  dropdownArrow.appendChild(dropdownMenu);

  wrapper.appendChild(dropdownArrow);
})();

//------------------------------------------------
// Fix a border around the "past chats" menu
//------------------------------------------------
document.getElementById("past-chats").parentNode.style.borderRadius = "0px";

//------------------------------------------------
// Allow the character dropdown to coexist at the
// Chat tab and the Parameters > Character tab
//------------------------------------------------

const headerBar = document.querySelector(".header_bar");
let originalParent;
let originalIndex; // To keep track of the original position
let movedElement;

/**
 * Move the character dropdown block into the chat controls sidebar.
 * The first call records the block's original parent and index so
 * restoreOriginalPosition() can put it back.
 */
function moveToChatTab() {
  const dropdownBlock = document.getElementById("character-menu").parentElement.parentElement;

  if (!originalParent) {
    originalParent = dropdownBlock.parentElement;
    originalIndex = Array.prototype.indexOf.call(originalParent.children, dropdownBlock);
    movedElement = dropdownBlock;
  }

  // Hide the Character dropdown in the Chat tab while "instruct" mode is selected
  const instructRadio = document.querySelector("#chat-mode input[value=\"instruct\"]");
  if (instructRadio?.checked) {
    dropdownBlock.style.display = "none";
  }

  dropdownBlock.children[0].style.minWidth = "100%";

  // Insert before the third-from-last child of the first #chat-controls child
  const destination = document.querySelector("#chat-controls").firstElementChild;
  destination.insertBefore(dropdownBlock, destination.children[destination.children.length - 3]);

  document.getElementById("save-character").style.display = "none";
  document.getElementById("restore-character").style.display = "none";
}

/** Put the character dropdown back where moveToChatTab() first found it. */
function restoreOriginalPosition() {
  if (!originalParent || !movedElement) {
    return;
  }

  // A null reference node makes insertBefore append at the end
  const reference = originalIndex < originalParent.children.length
    ? originalParent.children[originalIndex]
    : null;
  originalParent.insertBefore(movedElement, reference);

  document.getElementById("save-character").style.display = "";
  document.getElementById("restore-character").style.display = "";
  movedElement.style.display = "";
  movedElement.children[0].style.minWidth = "";
}

headerBar.addEventListener("click", (e) => {
  if (e.target.tagName !== "BUTTON") {
    return;
  }
  if (e.target.textContent.trim() === "Chat") {
    moveToChatTab();
  } else {
    restoreOriginalPosition();
  }
});

//------------------------------------------------
// Add a confirmation dialog when leaving the page
// Useful to avoid data loss
//------------------------------------------------
window.addEventListener("beforeunload", (event) => {
  event.preventDefault(); // Cancel the event
  event.returnValue = ""; // Chrome requires returnValue to be set
});

moveToChatTab();

//------------------------------------------------
// Buttons to toggle the sidebars
//------------------------------------------------

// NOTE(review): the SVG markup of these icons appears empty in this copy of
// the file; confirm against the upstream source.
const leftArrowSVG = ` `;

const rightArrowSVG = ` `;

const hamburgerMenuSVG = ` `;

const closeMenuSVG = ` `;

const chatTab = document.getElementById("chat-tab");
const pastChatsRow = document.getElementById("past-chats-row");
const chatControlsRow = document.getElementById("chat-controls");

if (chatTab) {
  // Past-chats toggle: starts open, showing a left arrow
  const pastToggleEl = document.createElement("div");
  pastToggleEl.id = "past-chats-toggle";
  pastToggleEl.innerHTML = leftArrowSVG;
  pastToggleEl.classList.add("past-chats-open");

  // Chat-controls toggle: starts open, showing a right arrow
  const controlsToggleEl = document.createElement("div");
  controlsToggleEl.id = "chat-controls-toggle";
  controlsToggleEl.innerHTML = rightArrowSVG;
  controlsToggleEl.classList.add("chat-controls-open");

  chatTab.appendChild(pastToggleEl);
  chatTab.appendChild(controlsToggleEl);
}

// Navigation toggle: starts with the left-arrow icon
const navigationToggle = document.createElement("div");
navigationToggle.id = "navigation-toggle";
navigationToggle.innerHTML = leftArrowSVG;
navigationToggle.classList.add("navigation-left");
headerBar.appendChild(navigationToggle);

// Look the toggles up again at top level so later code can reference them
const pastChatsToggle = document.getElementById("past-chats-toggle");
const chatControlsToggle = document.getElementById("chat-controls-toggle"); function handleIndividualSidebarClose(event) { const target = event.target; // Close navigation bar if click is outside and it is open if (!headerBar.contains(target) && !headerBar.classList.contains("sidebar-hidden")) { toggleSidebar(headerBar, navigationToggle, true); } // Close past chats row if click is outside and it is open if (!pastChatsRow.contains(target) && !pastChatsRow.classList.contains("sidebar-hidden")) { toggleSidebar(pastChatsRow, pastChatsToggle, true); } // Close chat controls row if click is outside and it is open if (!chatControlsRow.contains(target) && !chatControlsRow.classList.contains("sidebar-hidden")) { toggleSidebar(chatControlsRow, chatControlsToggle, true); } } function toggleSidebar(sidebar, toggle, forceClose = false) { const isCurrentlyHidden = sidebar.classList.contains("sidebar-hidden"); const shouldClose = !isCurrentlyHidden; // Apply visibility classes sidebar.classList.toggle("sidebar-hidden", shouldClose); sidebar.classList.toggle("sidebar-shown", !shouldClose); if (sidebar === headerBar) { // Special handling for header bar document.documentElement.style.setProperty("--header-width", shouldClose ? "0px" : "112px"); pastChatsRow.classList.toggle("negative-header", shouldClose); pastChatsToggle.classList.toggle("negative-header", shouldClose); toggle.innerHTML = shouldClose ? hamburgerMenuSVG : closeMenuSVG; } else if (sidebar === pastChatsRow) { // Past chats sidebar toggle.classList.toggle("past-chats-closed", shouldClose); toggle.classList.toggle("past-chats-open", !shouldClose); toggle.innerHTML = shouldClose ? rightArrowSVG : leftArrowSVG; } else if (sidebar === chatControlsRow) { // Chat controls sidebar toggle.classList.toggle("chat-controls-closed", shouldClose); toggle.classList.toggle("chat-controls-open", !shouldClose); toggle.innerHTML = shouldClose ? 
leftArrowSVG : rightArrowSVG; } // Mobile handling if (isMobile()) { sidebar.classList.toggle("sidebar-shown", !shouldClose); } } // Function to check if the device is mobile function isMobile() { return window.innerWidth <= 924; } // Function to initialize sidebars function initializeSidebars() { const isOnMobile = isMobile(); if (isOnMobile) { // Mobile state: Hide sidebars and set closed states [pastChatsRow, chatControlsRow, headerBar].forEach(el => { el.classList.add("sidebar-hidden"); el.classList.remove("sidebar-shown"); }); document.documentElement.style.setProperty("--header-width", "0px"); pastChatsRow.classList.add("negative-header"); pastChatsToggle.classList.add("negative-header", "past-chats-closed"); pastChatsToggle.classList.remove("past-chats-open"); [chatControlsToggle, navigationToggle].forEach(el => { el.classList.add("chat-controls-closed"); el.classList.remove("chat-controls-open"); }); pastChatsToggle.innerHTML = rightArrowSVG; chatControlsToggle.innerHTML = leftArrowSVG; navigationToggle.innerHTML = hamburgerMenuSVG; } else { // Desktop state: Show sidebars and set open states [pastChatsRow, chatControlsRow].forEach(el => { el.classList.remove("sidebar-hidden", "sidebar-shown"); }); pastChatsToggle.classList.add("past-chats-open"); pastChatsToggle.classList.remove("past-chats-closed"); [chatControlsToggle, navigationToggle].forEach(el => { el.classList.add("chat-controls-open"); el.classList.remove("chat-controls-closed"); }); pastChatsToggle.innerHTML = leftArrowSVG; chatControlsToggle.innerHTML = rightArrowSVG; navigationToggle.innerHTML = closeMenuSVG; } } // Run the initializer when the page loads initializeSidebars(); // Add click event listeners to toggle buttons pastChatsToggle.addEventListener("click", () => { const isCurrentlyOpen = !pastChatsRow.classList.contains("sidebar-hidden"); toggleSidebar(pastChatsRow, pastChatsToggle); // On desktop, open/close both sidebars at the same time if (!isMobile()) { if (isCurrentlyOpen) { // If 
we just closed the left sidebar, also close the right sidebar if (!chatControlsRow.classList.contains("sidebar-hidden")) { toggleSidebar(chatControlsRow, chatControlsToggle, true); } } else { // If we just opened the left sidebar, also open the right sidebar if (chatControlsRow.classList.contains("sidebar-hidden")) { toggleSidebar(chatControlsRow, chatControlsToggle, false); } } } }); chatControlsToggle.addEventListener("click", () => { const isCurrentlyOpen = !chatControlsRow.classList.contains("sidebar-hidden"); toggleSidebar(chatControlsRow, chatControlsToggle); // On desktop, open/close both sidebars at the same time if (!isMobile()) { if (isCurrentlyOpen) { // If we just closed the right sidebar, also close the left sidebar if (!pastChatsRow.classList.contains("sidebar-hidden")) { toggleSidebar(pastChatsRow, pastChatsToggle, true); } } else { // If we just opened the right sidebar, also open the left sidebar if (pastChatsRow.classList.contains("sidebar-hidden")) { toggleSidebar(pastChatsRow, pastChatsToggle, false); } } } }); navigationToggle.addEventListener("click", () => { toggleSidebar(headerBar, navigationToggle); }); //------------------------------------------------ // Fixes #chat-input textarea height issue // for devices with width <= 924px //------------------------------------------------ if (isMobile()) { // Target the textarea const textarea = document.querySelector("#chat-input textarea"); if (textarea) { // Simulate adding and removing a newline textarea.value += "\n"; textarea.dispatchEvent(new Event("input", { bubbles: true })); textarea.value = textarea.value.slice(0, -1); textarea.dispatchEvent(new Event("input", { bubbles: true })); } } //------------------------------------------------ // Create a top navigation bar on mobile //------------------------------------------------ function createMobileTopBar() { const chatTab = document.getElementById("chat-tab"); // Only create the top bar if it doesn't already exist if (chatTab && 
!chatTab.querySelector(".mobile-top-bar")) { const topBar = document.createElement("div"); topBar.classList.add("mobile-top-bar"); // Insert the top bar as the first child of chat-tab chatTab.appendChild(topBar); } } createMobileTopBar(); //------------------------------------------------ // Simple Navigation Functions //------------------------------------------------ function navigateLastAssistantMessage(direction) { const chat = document.querySelector("#chat"); if (!chat) return false; const messages = chat.querySelectorAll("[data-index]"); if (messages.length === 0) return false; // Find the last assistant message (starting from the end) let lastAssistantMessage = null; for (let i = messages.length - 1; i >= 0; i--) { const msg = messages[i]; if ( msg.classList.contains("assistant-message") || msg.querySelector(".circle-bot") || msg.querySelector(".text-bot") ) { lastAssistantMessage = msg; break; } } if (!lastAssistantMessage) return false; const buttons = lastAssistantMessage.querySelectorAll(".version-nav-button"); for (let i = 0; i < buttons.length; i++) { const button = buttons[i]; const onclick = button.getAttribute("onclick"); const disabled = button.hasAttribute("disabled"); const isLeft = onclick && onclick.includes("'left'"); const isRight = onclick && onclick.includes("'right'"); if (!disabled) { if (direction === "left" && isLeft) { navigateVersion(button, direction); return true; } if (direction === "right" && isRight) { navigateVersion(button, direction); return true; } } } return false; } //------------------------------------------------ // Paste Handler for Long Text //------------------------------------------------ const MAX_PLAIN_TEXT_LENGTH = 2500; function setupPasteHandler() { const textbox = document.querySelector("#chat-input textarea[data-testid=\"textbox\"]"); const fileInput = document.querySelector("#chat-input input[data-testid=\"file-upload\"]"); if (!textbox || !fileInput) { setTimeout(setupPasteHandler, 500); return; } 
textbox.addEventListener("paste", async (event) => { const text = event.clipboardData?.getData("text"); if (text && text.length > MAX_PLAIN_TEXT_LENGTH && document.querySelector("#paste_to_attachment input[data-testid=\"checkbox\"]")?.checked) { event.preventDefault(); const file = new File([text], "pasted_text.txt", { type: "text/plain", lastModified: Date.now() }); const dataTransfer = new DataTransfer(); dataTransfer.items.add(file); fileInput.files = dataTransfer.files; fileInput.dispatchEvent(new Event("change", { bubbles: true })); } }); } if (document.readyState === "loading") { document.addEventListener("DOMContentLoaded", setupPasteHandler); } else { setupPasteHandler(); } //------------------------------------------------ // Tooltips //------------------------------------------------ // File upload button document.querySelector("#chat-input .upload-button").title = "Upload text files, PDFs, DOCX documents, and images"; // Activate web search document.getElementById("web-search").title = "Search the internet with DuckDuckGo"; //------------------------------------------------ // Inline icons for deleting past chats //------------------------------------------------ function addMiniDeletes() { document.querySelectorAll("#past-chats label:not(.has-delete)").forEach(label => { const container = document.createElement("span"); container.className = "delete-container"; label.classList.add("chat-label-with-delete"); const trashBtn = document.createElement("button"); trashBtn.innerHTML = "🗑️"; trashBtn.className = "trash-btn"; const cancelBtn = document.createElement("button"); cancelBtn.innerHTML = "✕"; cancelBtn.className = "cancel-btn"; const confirmBtn = document.createElement("button"); confirmBtn.innerHTML = "✓"; confirmBtn.className = "confirm-btn"; label.addEventListener("mouseenter", () => { container.style.opacity = "1"; }); label.addEventListener("mouseleave", () => { container.style.opacity = "0"; }); trashBtn.onclick = (e) => { e.stopPropagation(); 
label.querySelector("input").click(); document.querySelector("#delete_chat").click(); trashBtn.style.display = "none"; cancelBtn.style.display = "flex"; confirmBtn.style.display = "flex"; }; cancelBtn.onclick = (e) => { e.stopPropagation(); document.querySelector("#delete_chat-cancel").click(); resetButtons(); }; confirmBtn.onclick = (e) => { e.stopPropagation(); document.querySelector("#delete_chat-confirm").click(); resetButtons(); }; function resetButtons() { trashBtn.style.display = "inline"; cancelBtn.style.display = "none"; confirmBtn.style.display = "none"; } container.append(trashBtn, cancelBtn, confirmBtn); label.appendChild(container); label.classList.add("has-delete"); }); } new MutationObserver(() => addMiniDeletes()).observe( document.querySelector("#past-chats"), {childList: true, subtree: true} ); addMiniDeletes(); //------------------------------------------------ // Fix autoscroll after fonts load //------------------------------------------------ document.fonts.addEventListener("loadingdone", (event) => { setTimeout(() => { if (!window.isScrolled) { const maxScroll = targetElement.scrollHeight - targetElement.clientHeight; if (targetElement.scrollTop < maxScroll - 5) { targetElement.scrollTop = maxScroll; } } }, 50); }); (function() { const chatParent = document.querySelector(".chat-parent"); const chatInputRow = document.querySelector("#chat-input-row"); const originalMarginBottom = 75; let originalHeight = chatInputRow.offsetHeight; function updateMargin() { const currentHeight = chatInputRow.offsetHeight; const heightDifference = currentHeight - originalHeight; chatParent.style.marginBottom = `${originalMarginBottom + heightDifference}px`; if (!window.isScrolled) { chatParent.scrollTop = chatParent.scrollHeight - chatParent.clientHeight; } } // Watch for size changes that affect height new ResizeObserver(updateMargin).observe(chatInputRow); // Also listen for window resize window.addEventListener("resize", updateMargin); // Initial call to set 
the margin based on current state updateMargin(); })(); ================================================ FILE: js/save_files.js ================================================ // Functions for downloading JSON files function getCurrentTimestamp() { const now = new Date(); const timezoneOffset = now.getTimezoneOffset() * 60000; // Convert to milliseconds const localTime = new Date(now.getTime() - timezoneOffset); const formattedTimestamp = localTime.toISOString().replace(/[-:]/g, "").slice(0, 15); return formattedTimestamp; } function saveFile(contents, filename) { const element = document.createElement("a"); element.setAttribute("href", "data:text/plain;charset=utf-8," + encodeURIComponent(contents)); element.setAttribute("download", filename); element.style.display = "none"; document.body.appendChild(element); element.click(); document.body.removeChild(element); } function saveHistory(history, character, mode) { let path = null; if (["chat", "chat-instruct"].includes(mode) && character && character.trim() !== "") { path = `history_${character}_${getCurrentTimestamp()}.json`; } else { try { path = `history_${mode}_${getCurrentTimestamp()}.json`; } catch (error) { path = `history_${getCurrentTimestamp()}.json`; } } saveFile(history, path); } function saveSession(session) { let path = null; path = `session_${getCurrentTimestamp()}.json`; saveFile(session, path); } ================================================ FILE: js/show_controls.js ================================================ const chatParent = document.querySelector(".chat-parent"); function toggle_controls(value) { const extensions = document.querySelector("#extensions"); if (value) { // SHOW MODE: Click toggles to show hidden sidebars const navToggle = document.getElementById("navigation-toggle"); const pastChatsToggle = document.getElementById("past-chats-toggle"); if (navToggle && document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) { navToggle.click(); } if (pastChatsToggle && 
document.getElementById("past-chats-row")?.classList.contains("sidebar-hidden")) { pastChatsToggle.click(); } // Show extensions only if (extensions) { extensions.style.display = "inherit"; } let gallery_element = document.getElementById("gallery-extension"); if (gallery_element) { gallery_element.style.display = "block"; } } else { // HIDE MODE: Click toggles to hide visible sidebars const navToggle = document.getElementById("navigation-toggle"); const pastChatsToggle = document.getElementById("past-chats-toggle"); if (navToggle && !document.querySelector(".header_bar")?.classList.contains("sidebar-hidden")) { navToggle.click(); } if (pastChatsToggle && !document.getElementById("past-chats-row")?.classList.contains("sidebar-hidden")) { pastChatsToggle.click(); } // Hide extensions only if (extensions) { extensions.style.display = "none"; } } } ================================================ FILE: js/switch_tabs.js ================================================ function scrollToTop() { window.scrollTo({ top: 0 }); } function findButtonsByText(buttonText) { const buttons = document.getElementsByTagName("button"); const matchingButtons = []; for (let i = 0; i < buttons.length; i++) { if (buttons[i].textContent.trim() === buttonText) { matchingButtons.push(buttons[i]); } } return matchingButtons; } function switch_to_chat() { document.getElementById("chat-tab-button").click(); scrollToTop(); } function switch_to_notebook() { document.getElementById("notebook-parent-tab-button").click(); findButtonsByText("Raw")[1].click(); scrollToTop(); } function switch_to_generation_parameters() { document.getElementById("parameters-button").click(); findButtonsByText("Generation")[0].click(); scrollToTop(); } function switch_to_character() { document.getElementById("character-tab-button").click(); scrollToTop(); } function switch_to_image_ai_generate() { const container = document.querySelector("#image-ai-tab"); const buttons = container.getElementsByTagName("button"); for (let 
i = 0; i < buttons.length; i++) { if (buttons[i].textContent.trim() === "Generate") { buttons[i].click(); break; } } scrollToTop(); } ================================================ FILE: js/update_big_picture.js ================================================ function updateBigPicture() { var existingElement = document.querySelector(".bigProfilePicture"); if (existingElement) { var timestamp = new Date().getTime(); existingElement.src = "/file/user_data/cache/pfp_character.png?time=" + timestamp; } } ================================================ FILE: modules/LoRA.py ================================================ from pathlib import Path import modules.shared as shared from modules.logging_colors import logger def add_lora_to_model(lora_names): add_lora_transformers(lora_names) def get_lora_path(lora_name): p = Path(lora_name) if p.exists(): lora_name = p.parts[-1] return Path(f"{shared.args.lora_dir}/{lora_name}") def add_lora_transformers(lora_names): from peft import PeftModel from modules.torch_utils import get_device prior_set = set(shared.lora_names) added_set = set(lora_names) - prior_set removed_set = prior_set - set(lora_names) # If no LoRA needs to be added or removed, exit if len(added_set) == 0 and len(removed_set) == 0: return # Add a LoRA when another LoRA is already present if len(removed_set) == 0 and len(prior_set) > 0 and "__merged" not in shared.model.peft_config.keys(): logger.info(f"Adding the LoRA(s) named {added_set} to the model") for lora in added_set: shared.model.load_adapter(get_lora_path(lora), lora) if len(lora_names) > 1: merge_loras() shared.lora_names = lora_names return # If any LoRA needs to be removed, start over if len(removed_set) > 0: shared.model = shared.model.unload() if len(lora_names) > 0: params = {} if not shared.args.cpu: if not shared.args.load_in_4bit and not shared.args.load_in_8bit: params['dtype'] = shared.model.dtype if hasattr(shared.model, "hf_device_map"): params['device_map'] = {"base_model.model." 
+ k: v for k, v in shared.model.hf_device_map.items()} logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names))) shared.model = PeftModel.from_pretrained(shared.model, get_lora_path(lora_names[0]), adapter_name=lora_names[0], **params) for lora in lora_names[1:]: shared.model.load_adapter(get_lora_path(lora), lora) if len(lora_names) > 1: merge_loras() if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): device = get_device() if device: shared.model = shared.model.to(device) shared.lora_names = lora_names def merge_loras(): if len(list({shared.model.peft_config[adapter].r for adapter in shared.model.peft_config.keys()})) > 1: logger.warning("The loaded LoRAs cannot be merged, as they have dissimilar ranks. Only the first one will be active.") return shared.model.add_weighted_adapter(shared.lora_names, [1] * len(shared.lora_names), "__merged") shared.model.set_adapter("__merged") ================================================ FILE: modules/callbacks.py ================================================ import traceback from queue import Queue from threading import Thread import modules.shared as shared class StopNowException(Exception): pass class Iteratorize: """ Transforms a function that takes a callback into a lazy iterator (generator). 
Adapted from: https://stackoverflow.com/a/9969000 """ def __init__(self, func, args=None, kwargs=None, callback=None): self.mfunc = func self.c_callback = callback self.q = Queue() self.sentinel = object() self.args = args or [] self.kwargs = kwargs or {} self.stop_now = False def _callback(val): if self.stop_now or shared.stop_everything: raise StopNowException self.q.put(val) def gentask(): try: ret = self.mfunc(callback=_callback, *args, **self.kwargs) except StopNowException: pass except Exception: traceback.print_exc() pass self.q.put(self.sentinel) if self.c_callback: self.c_callback(ret) self.thread = Thread(target=gentask) self.thread.start() def __iter__(self): return self def __next__(self): obj = self.q.get(True, None) if obj is self.sentinel: raise StopIteration else: return obj def __del__(self): pass def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop_now = True ================================================ FILE: modules/chat.py ================================================ import base64 import copy import functools import html import json import pprint import re import shutil import threading import time from datetime import datetime from functools import partial from pathlib import Path import markupsafe import yaml from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from PIL import Image import modules.shared as shared from modules import utils from modules.extensions import apply_extensions from modules.html_generator import ( chat_html_wrapper, convert_to_markdown, extract_thinking_block, make_thumbnail ) from modules.image_utils import open_image_safely from modules.logging_colors import logger from modules.reasoning import THINKING_FORMATS from modules.text_generation import ( generate_reply, get_encoded_length, get_max_prompt_length ) from modules.utils import ( delete_file, get_available_characters, get_available_users, sanitize_filename, save_file ) from 
modules.web_search import add_web_search_attachments _history_file_lock = threading.Lock() def strftime_now(format): return datetime.now().strftime(format) def get_current_timestamp(): """Returns the current time in 24-hour format""" return datetime.now().strftime('%b %d, %Y %H:%M') def update_message_metadata(metadata_dict, role, index, **fields): """ Updates or adds metadata fields for a specific message. Args: metadata_dict: The metadata dictionary role: The role (user, assistant, etc) index: The message index **fields: Arbitrary metadata fields to update/add """ key = f"{role}_{index}" if key not in metadata_dict: metadata_dict[key] = {} # Update with provided fields for field_name, field_value in fields.items(): metadata_dict[key][field_name] = field_value jinja_env = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, extensions=[loopcontrols] ) def custom_tojson(value, indent=None, ensure_ascii=True): return markupsafe.Markup(json.dumps(value, indent=indent, ensure_ascii=ensure_ascii)) jinja_env.filters["tojson"] = custom_tojson jinja_env.globals["strftime_now"] = strftime_now def _raise_exception(message): raise ValueError(message) jinja_env.globals["raise_exception"] = _raise_exception _template_cache = {} def get_compiled_template(template_str): """Cache compiled Jinja2 templates keyed by their source string.""" compiled = _template_cache.get(template_str) if compiled is None: compiled = jinja_env.from_string(template_str) _template_cache[template_str] = compiled return compiled def str_presenter(dumper, data): """ Copied from https://github.com/yaml/pyyaml/issues/240 Makes pyyaml output prettier multiline strings. 
""" if data.count('\n') > 0: return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') return dumper.represent_scalar('tag:yaml.org,2002:str', data) yaml.add_representer(str, str_presenter) yaml.representer.SafeRepresenter.add_representer(str, str_presenter) class _JsonDict(dict): """A dict that serializes as JSON when used in string concatenation. Some Jinja2 templates (Qwen, GLM) iterate arguments with .items(), requiring a dict. Others (DeepSeek) concatenate arguments as a string, requiring JSON. This class satisfies both. """ def __str__(self): return json.dumps(self, ensure_ascii=False) def __add__(self, other): return str(self) + other def __radd__(self, other): return other + str(self) def _deserialize_tool_call_arguments(tool_calls): """Convert tool_call arguments from JSON strings to _JsonDict. The OpenAI API spec sends arguments as a JSON string, but Jinja2 templates may need a dict (.items()) or a string (concatenation). _JsonDict handles both transparently. """ result = [] for tc in tool_calls: tc = copy.copy(tc) func = tc.get('function', {}) if isinstance(func, dict): func = dict(func) args = func.get('arguments') if isinstance(args, str): try: func['arguments'] = _JsonDict(json.loads(args)) except (json.JSONDecodeError, ValueError): pass elif isinstance(args, dict) and not isinstance(args, _JsonDict): func['arguments'] = _JsonDict(args) tc['function'] = func result.append(tc) return result def _expand_tool_sequence(tool_seq): """Expand a tool_sequence list into API messages. Returns a list of dicts (role: assistant with tool_calls, or role: tool). If any tool_call IDs are missing a matching tool result, a synthetic empty result is inserted so the prompt is never malformed. 
""" messages = [] expected_ids = [] seen_ids = set() for item in tool_seq: if 'tool_calls' in item: deserialized = _deserialize_tool_call_arguments(item['tool_calls']) messages.append({ "role": "assistant", "content": item.get('content', ''), "tool_calls": deserialized }) for tc in item['tool_calls']: tc_id = tc.get('id', '') if tc_id: expected_ids.append(tc_id) elif item.get('role') == 'tool': messages.append({ "role": "tool", "content": item['content'], "tool_call_id": item.get('tool_call_id', '') }) seen_ids.add(item.get('tool_call_id', '')) # Fill in synthetic results for any orphaned tool call IDs for tc_id in expected_ids: if tc_id not in seen_ids: messages.append({ "role": "tool", "content": "", "tool_call_id": tc_id }) return messages def generate_chat_prompt(user_input, state, **kwargs): impersonate = kwargs.get('impersonate', False) _continue = kwargs.get('_continue', False) also_return_rows = kwargs.get('also_return_rows', False) history_data = kwargs.get('history', state['history']) history = history_data['internal'] metadata = history_data.get('metadata', {}) # Templates chat_template_str = state['chat_template_str'] if state['mode'] != 'instruct': chat_template_str = replace_character_names(chat_template_str, state['name1'], state['name2']) instruction_template = get_compiled_template(state['instruction_template_str']) chat_template = get_compiled_template(chat_template_str) instruct_renderer = partial( instruction_template.render, builtin_tools=None, tools=state['tools'] if 'tools' in state else None, tools_in_user_message=False, add_generation_prompt=False, enable_thinking=state['enable_thinking'], reasoning_effort=state['reasoning_effort'], thinking_budget=-1 if state.get('enable_thinking', True) else 0, bos_token=shared.bos_token, eos_token=shared.eos_token, ) chat_renderer = partial( chat_template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'], user_bio=replace_character_names(state['user_bio'], state['name1'], 
state['name2']), tools=state['tools'] if 'tools' in state else None, ) messages = [] if state['mode'] == 'instruct': renderer = instruct_renderer if state['custom_system_message'].strip() != '': messages.append({"role": "system", "content": state['custom_system_message']}) else: renderer = chat_renderer if state['context'].strip() != '' or state['user_bio'].strip() != '': context = replace_character_names(state['context'], state['name1'], state['name2']) messages.append({"role": "system", "content": context}) insert_pos = len(messages) for i, entry in enumerate(reversed(history)): user_msg = entry[0].strip() assistant_msg = entry[1].strip() tool_msg = entry[2].strip() if len(entry) > 2 else '' entry_meta = entry[3] if len(entry) > 3 else {} row_idx = len(history) - i - 1 if tool_msg: tool_message = {"role": "tool", "content": tool_msg} if "tool_call_id" in entry_meta: tool_message["tool_call_id"] = entry_meta["tool_call_id"] messages.insert(insert_pos, tool_message) if not assistant_msg and entry_meta.get('tool_calls'): # Assistant message with only tool_calls and no text content messages.insert(insert_pos, {"role": "assistant", "content": "", "tool_calls": _deserialize_tool_call_arguments(entry_meta['tool_calls'])}) elif assistant_msg: # Handle GPT-OSS as a special case if '<|channel|>analysis<|message|>' in assistant_msg or '<|channel|>final<|message|>' in assistant_msg: thinking_content = "" final_content = "" # Extract analysis content if present if '<|channel|>analysis<|message|>' in assistant_msg: parts = assistant_msg.split('<|channel|>analysis<|message|>', 1) if len(parts) > 1: # The content is everything after the tag potential_content = parts[1] # Now, find the end of this content block analysis_end_tag = '<|end|>' if analysis_end_tag in potential_content: thinking_content = potential_content.split(analysis_end_tag, 1)[0].strip() else: # Fallback: if no <|end|> tag, stop at the start of the final channel if it exists final_channel_tag = 
'<|channel|>final<|message|>' if final_channel_tag in potential_content: thinking_content = potential_content.split(final_channel_tag, 1)[0].strip() else: thinking_content = potential_content.strip() # Extract final content if present final_tag_to_find = '<|channel|>final<|message|>' if final_tag_to_find in assistant_msg: parts = assistant_msg.split(final_tag_to_find, 1) if len(parts) > 1: # The content is everything after the tag potential_content = parts[1] # Now, find the end of this content block final_end_tag = '<|end|>' if final_end_tag in potential_content: final_content = potential_content.split(final_end_tag, 1)[0].strip() else: final_content = potential_content.strip() # Insert as structured message msg_dict = {"role": "assistant", "content": final_content} if '<|channel|>analysis<|message|>' in assistant_msg: msg_dict["thinking"] = thinking_content messages.insert(insert_pos, msg_dict) # Handle Seed-OSS elif '' in assistant_msg: thinking_content = "" final_content = assistant_msg # Extract thinking content if present if '' in assistant_msg: parts = assistant_msg.split('', 1) if len(parts) > 1: potential_content = parts[1] if '' in potential_content: thinking_content = potential_content.split('', 1)[0].strip() final_content = parts[0] + potential_content.split('', 1)[1] else: thinking_content = potential_content.strip() final_content = parts[0] # Insert as structured message msg_dict = {"role": "assistant", "content": final_content.strip()} if thinking_content: msg_dict["reasoning_content"] = thinking_content messages.insert(insert_pos, msg_dict) else: # Default case (used by all other models) messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg}) # Attach tool_calls metadata to the assistant message if present if entry_meta.get('tool_calls') and messages[insert_pos].get('role') == 'assistant': messages[insert_pos]['tool_calls'] = _deserialize_tool_call_arguments(entry_meta['tool_calls']) # Expand tool_sequence from metadata 
(inserted AFTER assistant so that # the final order is: user → tool_calls → tool_results → final_answer) meta_key = f"assistant_{row_idx}" tool_seq = metadata.get(meta_key, {}).get('tool_sequence', []) if tool_seq: for msg in reversed(_expand_tool_sequence(tool_seq)): messages.insert(insert_pos, msg) if entry_meta.get('role') == 'system': if user_msg: messages.insert(insert_pos, {"role": "system", "content": user_msg}) elif user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']: # Check for user message attachments in metadata user_key = f"user_{row_idx}" enhanced_user_msg = user_msg # Add attachment content if present AND if past attachments are enabled if user_key in metadata and "attachments" in metadata[user_key]: attachments_text = "" image_refs = "" for attachment in metadata[user_key]["attachments"]: if attachment.get("type") == "image": # Add image reference for multimodal models image_refs += "<__media__>" elif state.get('include_past_attachments', True): # Handle text/PDF attachments filename = attachment.get("name", "file") content = attachment.get("content", "") if attachment.get("type") == "text/html" and attachment.get("url"): attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n" else: attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n" if image_refs: enhanced_user_msg = f"{image_refs}\n\n{enhanced_user_msg}" if attachments_text: enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}" messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg}) # Handle the current user input user_input = user_input.strip() # Check if we have attachments if not (impersonate or _continue): has_attachments = False if len(history_data.get('metadata', {})) > 0: current_row_idx = len(history) user_key = f"user_{current_row_idx}" has_attachments = user_key in metadata and "attachments" in metadata[user_key] if user_input or has_attachments: # For the current user input being 
processed, check if we need to add attachments if len(history_data.get('metadata', {})) > 0: current_row_idx = len(history) user_key = f"user_{current_row_idx}" if user_key in metadata and "attachments" in metadata[user_key]: attachments_text = "" image_refs = "" for attachment in metadata[user_key]["attachments"]: if attachment.get("type") == "image": image_refs += "<__media__>" else: filename = attachment.get("name", "file") content = attachment.get("content", "") if attachment.get("type") == "text/html" and attachment.get("url"): attachments_text += f"\nName: {filename}\nURL: {attachment['url']}\nContents:\n\n=====\n{content}\n=====\n\n" else: attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n" if image_refs: user_input = f"{image_refs}\n\n{user_input}" if attachments_text: user_input += f"\n\nATTACHMENTS:\n{attachments_text}" messages.append({"role": "user", "content": user_input}) # Expand tool_sequence for the current entry (excluded from the # history loop during regenerate — needed so the model sees prior # tool calls and results when re-generating the final answer). 
current_tool_seq = metadata.get(f"assistant_{len(history)}", {}).get('tool_sequence', []) messages.extend(_expand_tool_sequence(current_tool_seq)) if impersonate and state['mode'] != 'chat-instruct': messages.append({"role": "user", "content": "fake user message replace me"}) def make_prompt(messages): last_message = messages[-1].copy() if _continue: if state['mode'] == 'chat-instruct': messages = messages[:-1] else: messages[-1]["content"] = "fake assistant message replace me" messages.append({"role": "assistant", "content": "this will get deleted"}) if state['mode'] != 'chat-instruct': add_generation_prompt = (not _continue and not impersonate) else: add_generation_prompt = False prompt = renderer( messages=messages, add_generation_prompt=add_generation_prompt ) if state['mode'] == 'chat-instruct': command = state['chat-instruct_command'] command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1']) command = command.replace('<|prompt|>', prompt) command = replace_character_names(command, state['name1'], state['name2']) outer_messages = [] if state['custom_system_message'].strip() != '': outer_messages.append({"role": "system", "content": state['custom_system_message']}) outer_messages.append({"role": "user", "content": command}) if _continue: outer_messages.append(last_message.copy()) outer_messages[-1]["content"] = "fake assistant message replace me" outer_messages.append({"role": "assistant", "content": "this will get deleted"}) prompt = instruct_renderer( messages=outer_messages, add_generation_prompt=not _continue ) if _continue: prompt = prompt.split("fake assistant message replace me", 1)[0] content = last_message.get("content", "") partial_thought = last_message.get("thinking", "") or last_message.get("reasoning_content", "") # Handle partial thinking blocks (GPT-OSS and Seed-OSS) if not content and partial_thought and partial_thought.strip(): search_string = partial_thought.strip() index = prompt.rfind(search_string) 
if index != -1: prompt = prompt[:index] + partial_thought else: # Fallback if search fails: just append the thought prompt += partial_thought else: # All other cases prompt += content if impersonate: prompt = prompt.split("fake user message replace me", 1)[0] prompt += user_input if state['mode'] in ['chat', 'chat-instruct'] and not impersonate and not _continue: prompt += apply_extensions('bot_prefix', "", state) return prompt prompt = make_prompt(messages) # Handle truncation if shared.tokenizer is not None: max_length = get_max_prompt_length(state) encoded_length = get_encoded_length(prompt) while len(messages) > 0 and encoded_length > max_length: # Remove old message, save system message if len(messages) > 2 and messages[0]['role'] == 'system': messages.pop(1) # Remove old message when no system message is present elif len(messages) > 1 and messages[0]['role'] != 'system': messages.pop(0) # Resort to truncating the user input else: user_message = messages[-1]['content'] # Bisect the truncation point left, right = 0, len(user_message) while left < right: mid = (left + right + 1) // 2 messages[-1]['content'] = user_message[:mid] prompt = make_prompt(messages) encoded_length = get_encoded_length(prompt) if encoded_length <= max_length: left = mid else: right = mid - 1 messages[-1]['content'] = user_message[:left] prompt = make_prompt(messages) encoded_length = get_encoded_length(prompt) if encoded_length > max_length: logger.error(f"Failed to build the chat prompt. 
The input is too long for the available context length.\n\nTruncation length: {state['truncation_length']}\nmax_new_tokens: {state['max_new_tokens']} (is it too high?)\nAvailable context length: {max_length}\n") raise ValueError else: # Calculate token counts for the log message original_user_tokens = get_encoded_length(user_message) truncated_user_tokens = get_encoded_length(user_message[:left]) total_context = max_length + state['max_new_tokens'] logger.warning( f"User message truncated from {original_user_tokens} to {truncated_user_tokens} tokens. " f"Context full: {max_length} input tokens ({total_context} total, {state['max_new_tokens']} for output). " f"Increase ctx-size while loading the model to avoid truncation." ) break prompt = make_prompt(messages) encoded_length = get_encoded_length(prompt) if also_return_rows: return prompt, [message['content'] for message in messages] else: return prompt def count_prompt_tokens(text_input, state): """Count tokens for current history + input including attachments""" if shared.tokenizer is None: return "Tokenizer not available" try: # Handle dict format with text and files files = [] if isinstance(text_input, dict): files = text_input.get('files', []) text = text_input.get('text', '') else: text = text_input files = [] # Create temporary history copy to add attachments temp_history = copy.deepcopy(state['history']) if 'metadata' not in temp_history: temp_history['metadata'] = {} # Process attachments if any if files: row_idx = len(temp_history['internal']) for file_path in files: add_message_attachment(temp_history, row_idx, file_path, is_user=True) # Create temp state with modified history temp_state = copy.deepcopy(state) temp_state['history'] = temp_history # Build prompt using existing logic prompt = generate_chat_prompt(text, temp_state) current_tokens = get_encoded_length(prompt) max_tokens = temp_state['truncation_length'] percentage = (current_tokens / max_tokens) * 100 if max_tokens > 0 else 0 return f"History 
+ Input:
{current_tokens:,} / {max_tokens:,} tokens ({percentage:.1f}%)" except Exception as e: logger.error(f"Error counting tokens: {e}") return f"Error: {str(e)}" def get_stopping_strings(state): stopping_strings = [] renderers = [] if state['mode'] in ['instruct', 'chat-instruct']: template = get_compiled_template(state['instruction_template_str']) renderer = partial(template.render, add_generation_prompt=False, bos_token=shared.bos_token, eos_token=shared.eos_token) renderers.append(renderer) if state['mode'] in ['chat']: template = get_compiled_template(state['chat_template_str']) renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2']) renderers.append(renderer) fake_messages = [ {"role": "user", "content": "first user message"}, {"role": "assistant", "content": "first assistant message"}, {"role": "user", "content": "second user message"}, {"role": "assistant", "content": "second assistant message"}, ] stopping_strings = [] for renderer in renderers: prompt = renderer(messages=fake_messages) # Find positions of each message content first_user_end = prompt.find("first user message") + len("first user message") first_assistant_start = prompt.find("first assistant message") first_assistant_end = prompt.find("first assistant message") + len("first assistant message") second_user_start = prompt.find("second user message") second_assistant_end = prompt.find("second assistant message") + len("second assistant message") # Extract pieces of text potentially containing unique stopping strings texts = [ prompt[first_user_end:first_assistant_start], prompt[first_assistant_end:second_user_start], prompt[second_assistant_end:] ] for text in texts: stripped_text = text.strip() if stripped_text.startswith("<") and ">" in stripped_text: stopping_strings.append(stripped_text.split(">")[0] + ">") elif stripped_text.startswith("[") and "]" in stripped_text: stopping_strings.append(stripped_text.split("]")[0] + "]") elif 
stripped_text.startswith("(") and ")" in stripped_text: stopping_strings.append(stripped_text.split(")")[0] + ")") elif stripped_text.startswith("{") and "}" in stripped_text: stopping_strings.append(stripped_text.split("}")[0] + "}") elif ":" in text: stopping_strings.append(text.split(":")[0] + ":") if 'stopping_strings' in state and isinstance(state['stopping_strings'], list): stopping_strings += state.pop('stopping_strings') # Remove redundant items that start with another item result = [item for item in stopping_strings if not any(item.startswith(other) and item != other for other in stopping_strings)] result = list(set(result)) # Handle GPT-OSS as a special case if '<|channel|>final<|message|>' in state['instruction_template_str'] and "<|end|>" in result: result.remove("<|end|>") result.append("<|result|>") result = list(set(result)) if shared.args.verbose: logger.info("STOPPING_STRINGS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(result) print() return result def add_message_version(history, role, row_idx, is_current=True): key = f"{role}_{row_idx}" if 'metadata' not in history: history['metadata'] = {} if key not in history['metadata']: history['metadata'][key] = {} if "versions" not in history['metadata'][key]: history['metadata'][key]["versions"] = [] # Determine which index to use for content based on role content_idx = 0 if role == 'user' else 1 current_content = history['internal'][row_idx][content_idx] current_visible = history['visible'][row_idx][content_idx] history['metadata'][key]["versions"].append({ "content": current_content, "visible_content": current_visible, "timestamp": get_current_timestamp() }) if is_current: # Set the current_version_index to the newly added version (which is now the last one). 
history['metadata'][key]["current_version_index"] = len(history['metadata'][key]["versions"]) - 1 def add_message_attachment(history, row_idx, file_path, is_user=True): """Add a file attachment to a message in history metadata""" if 'metadata' not in history: history['metadata'] = {} key = f"{'user' if is_user else 'assistant'}_{row_idx}" if key not in history['metadata']: history['metadata'][key] = {"timestamp": get_current_timestamp()} if "attachments" not in history['metadata'][key]: history['metadata'][key]["attachments"] = [] # Get file info using pathlib path = Path(file_path) filename = path.name file_extension = path.suffix.lower() try: # Handle image files if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']: # Convert image to base64 with open(path, 'rb') as f: image_data = base64.b64encode(f.read()).decode('utf-8') # Determine MIME type from extension mime_type_map = { '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.png': 'image/png', '.webp': 'image/webp', '.bmp': 'image/bmp', '.gif': 'image/gif' } mime_type = mime_type_map.get(file_extension, 'image/jpeg') # Format as data URL data_url = f"data:{mime_type};base64,{image_data}" # Generate unique image ID image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1 attachment = { "name": filename, "type": "image", "image_data": data_url, "image_id": image_id, } elif file_extension == '.pdf': # Process PDF file content = extract_pdf_text(path) attachment = { "name": filename, "type": "application/pdf", "content": content, } elif file_extension == '.docx': content = extract_docx_text(path) attachment = { "name": filename, "type": "application/docx", "content": content, } else: # Default handling for text files with open(path, 'r', encoding='utf-8') as f: content = f.read() attachment = { "name": filename, "type": "text/plain", "content": content, } history['metadata'][key]["attachments"].append(attachment) return attachment # Return the 
attachment for reuse except Exception as e: logger.error(f"Error processing attachment {filename}: {e}") return None def extract_pdf_text(pdf_path): """Extract text from a PDF file""" import pymupdf text = "" try: with pymupdf.open(pdf_path) as doc: for page in doc: text += page.get_text() + "\n\n" return text.strip() except Exception as e: logger.error(f"Error extracting text from PDF: {e}") return f"[Error extracting PDF text: {str(e)}]" def extract_docx_text(docx_path): """ Extract text from a .docx file, including headers, body (paragraphs and tables), and footers. """ try: import docx doc = docx.Document(docx_path) parts = [] # 1) Extract non-empty header paragraphs from each section for section in doc.sections: for para in section.header.paragraphs: text = para.text.strip() if text: parts.append(text) # 2) Extract body blocks (paragraphs and tables) in document order parent_elm = doc.element.body for child in parent_elm.iterchildren(): if isinstance(child, docx.oxml.text.paragraph.CT_P): para = docx.text.paragraph.Paragraph(child, doc) text = para.text.strip() if text: parts.append(text) elif isinstance(child, docx.oxml.table.CT_Tbl): table = docx.table.Table(child, doc) for row in table.rows: cells = [cell.text.strip() for cell in row.cells] parts.append("\t".join(cells)) # 3) Extract non-empty footer paragraphs from each section for section in doc.sections: for para in section.footer.paragraphs: text = para.text.strip() if text: parts.append(text) return "\n".join(parts) except Exception as e: logger.error(f"Error extracting text from DOCX: {e}") return f"[Error extracting DOCX text: {str(e)}]" def generate_search_query(user_message, state): """Generate a search query from user message using the LLM""" # Augment the user message with search instruction augmented_message = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else." 
# Use a minimal state for search query generation but keep the full history search_state = state.copy() search_state['auto_max_new_tokens'] = True search_state['enable_thinking'] = False search_state['reasoning_effort'] = 'low' search_state['start_with'] = "" # Generate the full prompt using existing history + augmented message formatted_prompt = generate_chat_prompt(augmented_message, search_state) query = "" for reply in generate_reply(formatted_prompt, search_state, stopping_strings=[], is_chat=True): query = reply # Check for thinking block delimiters and extract content after them if "" in query: query = query.rsplit("", 1)[1] elif "<|start|>assistant<|channel|>final<|message|>" in query: query = query.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1] elif "<|channel|>final<|message|>" in query: query = query.rsplit("<|channel|>final<|message|>", 1)[1] elif "" in query: query = query.rsplit("", 1)[1] # Strip and remove surrounding quotes if present query = query.strip() if len(query) >= 2 and query.startswith('"') and query.endswith('"'): query = query[1:-1] return query


def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
    """
    Generator that runs one chat turn and yields the growing history dict
    (keys 'internal', 'visible', 'metadata') while the reply streams.

    - Normal send: `text` (str, or dict with 'text' and 'files') becomes a new
      row; files and optional web-search results are stored as attachments
      in metadata, not in the message text.
    - regenerate: re-generates the last reply; unless state['_tool_turn'] is
      set, the previous reply is kept as a version in metadata.
    - _continue: appends to the last reply instead of replacing it.

    Works on a deep copy of state['history']; the caller owns persisting it.
    """
    # Handle dict format with text and files
    files = []
    if isinstance(text, dict):
        files = text.get('files', [])
        text = text.get('text', '')

    history = state['history']
    output = copy.deepcopy(history)
    output = apply_extensions('history', output)
    state = apply_extensions('state', state)

    # Handle GPT-OSS as a special case
    if '<|channel|>final<|message|>' in state['instruction_template_str']:
        state['skip_special_tokens'] = False

    # Let the jinja2 template handle the BOS token
    if state['mode'] in ['instruct', 'chat-instruct']:
        state['add_bos_token'] = False

    # Initialize metadata if not present
    if 'metadata' not in output:
        output['metadata'] = {}

    visible_text = None
    stopping_strings = get_stopping_strings(state)
    is_stream = state['stream']

    # Prepare the input
    if not (regenerate or _continue):
        visible_text = html.escape(text)

        # Process file attachments and store in metadata
        row_idx = len(output['internal'])

        # Add attachments to metadata only, not modifying the message text
        for file_path in files:
            add_message_attachment(output, row_idx, file_path, is_user=True)

        # Add web search results as attachments if enabled
        if state.get('enable_web_search', False):
            search_query = generate_search_query(text, state)
            add_web_search_attachments(output, row_idx, text, search_query, state)

        # Apply extensions
        text, visible_text = apply_extensions('chat_input', text, visible_text, state)
        text = apply_extensions('input', text, state, is_chat=True)

        # Current row index
        output['internal'].append([text, ''])
        output['visible'].append([visible_text, ''])
        # Add metadata with timestamp
        update_message_metadata(output['metadata'], "user", row_idx, timestamp=get_current_timestamp())

        # *Is typing...*
        if loading_message:
            yield {
                'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]],
                'internal': output['internal'],
                'metadata': output['metadata']
            }
    else:
        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
        # Tool-loop turns (state['_tool_turn']) skip version bookkeeping so
        # intermediate regenerations don't add swipe versions.
        if regenerate and not state.get('_tool_turn'):
            row_idx = len(output['internal']) - 1

            # Store the old response as a version before regenerating
            if not output['metadata'].get(f"assistant_{row_idx}", {}).get('versions'):
                add_message_version(output, "assistant", row_idx, is_current=False)

            # Add new empty version (will be filled during streaming)
            key = f"assistant_{row_idx}"
            output['metadata'][key]["versions"].append({
                "content": "",
                "visible_content": "",
                "timestamp": get_current_timestamp()
            })
            output['metadata'][key]["current_version_index"] = len(output['metadata'][key]["versions"]) - 1

            if loading_message:
                yield {
                    'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
                    'internal': output['internal'][:-1] + [[text, '']],
                    'metadata': output['metadata']
                }
        elif _continue:
            # [internal, visible] of the reply being continued
            last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
            if loading_message:
                yield {
                    'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']],
                    'internal': output['internal'],
                    'metadata': output['metadata']
                }

    row_idx = len(output['internal']) - 1

    # Collect image attachments for multimodal generation from the entire history
    all_image_attachments = []
    if 'metadata' in output:
        for i in range(len(output['internal'])):
            user_key = f"user_{i}"
            if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
                for attachment in output['metadata'][user_key]["attachments"]:
                    if attachment.get("type") == "image":
                        all_image_attachments.append(attachment)

    # Add all collected image attachments to state for the generation
    if all_image_attachments:
        state['image_attachments'] = all_image_attachments

    # Generate the prompt
    # Outside _continue, the last (in-progress) row is excluded from the
    # history passed to the prompt builder; `text` is passed separately.
    kwargs = {
        '_continue': _continue,
        'history': output if _continue else {
            k: (v[:-1] if k in ['internal', 'visible'] else v)
            for k, v in output.items()
        }
    }

    prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
    if prompt is None:
        prompt = generate_chat_prompt(text, state, **kwargs)

    # Add timestamp for assistant's response at the start of generation
    update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp(), model_name=shared.model_name)

    # Detect if the template appended a thinking start tag to the prompt
    thinking_prefix = None
    if not _continue:
        stripped_prompt = prompt.rstrip('\n')
        for start_tag, end_tag, content_tag in THINKING_FORMATS:
            if start_tag is not None and stripped_prompt.endswith(start_tag):
                thinking_prefix = start_tag
                break

    # When tools are active, buffer streaming output during potential tool
    # call generation to prevent raw markup from leaking into the display.
    _check_tool_markers = bool(state.get('tools'))
    _last_visible_before_tool_buffer = None
    if _check_tool_markers:
        from modules.tool_parsing import streaming_tool_buffer_check, detect_tool_call_format
        _tool_names = [t['function']['name'] for t in state['tools'] if 'function' in t and 'name' in t['function']]
        _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
        _, _streaming_markers, _check_bare_names = detect_tool_call_format(_template_str)

    # Generate
    reply = None
    for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True, for_ui=for_ui)):

        # Prepend thinking tag if the template appended it to the prompt
        if thinking_prefix:
            reply = thinking_prefix + reply

        # Extract the reply
        if state['mode'] in ['chat', 'chat-instruct']:
            if not _continue:
                reply = reply.lstrip()

            # Drop a leading "Name:" prefix the model may have emitted
            if reply.startswith(state['name2'] + ':'):
                reply = reply[len(state['name2'] + ':'):]
            elif reply.startswith(state['name1'] + ':'):
                reply = reply[len(state['name1'] + ':'):]

            # NOTE(review): this alternation has two empty branches, so it
            # matches the empty string at every position and would insert
            # name1 between every character. The `<...>` placeholder
            # alternatives appear to be missing -- confirm the intended pattern.
            visible_reply = re.sub("(||{{user}})", state['name1'], reply)
        else:
            visible_reply = reply

        visible_reply = html.escape(visible_reply)

        if shared.stop_everything:
            if not state.get('_skip_output_extensions'):
                output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)

            yield output
            return

        if _continue:
            output['internal'][-1] = [text, last_reply[0] + reply]
            output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
        elif not (j == 0 and visible_reply.strip() == ''):
            # Skip a whitespace-only first chunk so the typing indicator stays
            output['internal'][-1] = [text, reply.lstrip(' ')]
            output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]

        # Keep version metadata in sync during streaming (for regeneration)
        if regenerate and not state.get('_tool_turn'):
            row_idx = len(output['internal']) - 1
            key = f"assistant_{row_idx}"
            current_idx = output['metadata'][key]['current_version_index']
            output['metadata'][key]['versions'][current_idx].update({
                'content': output['internal'][row_idx][1],
                'visible_content': output['visible'][row_idx][1]
            })

        if is_stream:
            if _check_tool_markers:
                # Hold back yields while the text may be a tool call in progress
                if streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
                    continue
                _last_visible_before_tool_buffer = output['visible'][-1][1]

            yield output

    if _continue:
        # Reprocess the entire internal text for extensions (like translation).
        # Skip entirely when the visible text contains tool-call markers,
        # since those only exist in visible (internal is cleared after each tool
        # execution) and rebuilding from internal would destroy them. Output
        # extensions also can't handle the raw markup safely.
        # NOTE(review): `'' in s` is always True, so as written this branch
        # never runs; the marker literal appears to be missing -- confirm.
        if '' not in output['visible'][-1][1]:
            full_internal = output['internal'][-1][1]
            if state['mode'] in ['chat', 'chat-instruct']:
                # NOTE(review): same empty-alternation pattern as above -- confirm.
                full_visible = re.sub("(||{{user}})", state['name1'], full_internal)
            else:
                full_visible = full_internal

            full_visible = html.escape(full_visible)
            if not state.get('_skip_output_extensions'):
                output['visible'][-1][1] = apply_extensions('output', full_visible, state, is_chat=True)
    else:
        if not state.get('_skip_output_extensions'):
            output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)

    # Final sync for version metadata (in case streaming was disabled)
    if regenerate and not state.get('_tool_turn'):
        row_idx = len(output['internal']) - 1
        key = f"assistant_{row_idx}"
        current_idx = output['metadata'][key]['current_version_index']
        output['metadata'][key]['versions'][current_idx].update({
            'content': output['internal'][row_idx][1],
            'visible_content': output['visible'][row_idx][1]
        })

    # When tool markers were detected during streaming, restore the last
    # visible text from before buffering started so raw markup doesn't flash
    # in the UI. The internal text is left intact so the caller can still
    # parse tool calls from it.
    if is_stream and _check_tool_markers and streaming_tool_buffer_check(output['internal'][-1][1], markers=_streaming_markers, tool_names=_tool_names, check_bare_names=_check_bare_names):
        output['visible'][-1][1] = _last_visible_before_tool_buffer or ''

    yield output


def impersonate_wrapper(textbox, state):
    """
    Generator that writes a model-generated user message into `textbox`,
    continuing whatever text is already in textbox['text'].

    Yields (textbox, static chat HTML); the chat display is not re-rendered.
    """
    text = textbox['text']
    static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

    prompt = generate_chat_prompt('', state, impersonate=True)
    stopping_strings = get_stopping_strings(state)

    textbox['text'] = text + '...'
    yield textbox, static_output
    reply = None
    for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True):
        textbox['text'] = (text + reply).lstrip(' ')
        yield textbox, static_output
        if shared.stop_everything:
            return


def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
    """
    Thin generator over chatbot_wrapper that yields history dicts.

    For regenerate/_continue the input text is ignored; when there is nothing
    to regenerate/continue (no rows, or only an empty greeting-less row) the
    current history is yielded once unchanged.
    """
    history = state['history']
    if regenerate or _continue:
        text = ''
        if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
            yield history
            return

    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message, for_ui=for_ui):
        yield history


def character_is_loaded(state, raise_exception=False):
    """
    Return False (after logging an error) when a chat mode is active but no
    character name (state['name2']) is set; raise ValueError instead when
    raise_exception is True. Returns True otherwise.
    """
    if state['mode'] in ['chat', 'chat-instruct'] and state['name2'] == '':
        logger.error('It looks like no character is loaded. Please load one under Parameters > Character.')
        if raise_exception:
            raise ValueError

        return False
    else:
        return True


def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
    '''
    Same as above but returns HTML for the UI.
    When tools are selected, wraps generation in a loop that detects
    tool calls, executes them, and re-generates until the model stops.
    All tool output is consolidated into a single visible chat bubble
    using metadata['assistant_N']['tool_sequence'].

    Yields (chat HTML, history) tuples and saves the history periodically
    and at the end.
    '''

    if not character_is_loaded(state):
        return

    # "Start reply with": seed the reply with the prefix and continue it
    if state['start_with'] != '' and not _continue:
        if regenerate:
            text, state['history'] = remove_last_message(state['history'])
            regenerate = False

        _continue = True
        send_dummy_message(text, state)
        send_dummy_reply(state['start_with'], state)

    # On regenerate, clear old tool_sequence metadata so it gets rebuilt.
    # Save it first so it can be stored per-version below.
    # This must happen after the start_with logic above, which may remove
    # and re-add messages, changing which row we operate on.
    _old_tool_sequence = None
    if regenerate:
        history = state['history']
        meta = history.get('metadata', {})
        row_idx = len(history['internal']) - 1
        if row_idx >= 0:
            _old_tool_sequence = meta.get(f'assistant_{row_idx}', {}).pop('tool_sequence', None)

    # Load tools if any are selected
    selected = state.get('selected_tools', [])
    parse_tool_call = None
    _tool_parsers = None
    if selected:
        from modules.tool_use import load_tools, execute_tool
        from modules.tool_parsing import parse_tool_call, get_tool_call_id, detect_tool_call_format

    if selected:
        tool_defs, tool_executors = load_tools(selected)
        state['tools'] = tool_defs
        tool_func_names = [t['function']['name'] for t in tool_defs]
        _template_str = state.get('instruction_template_str', '') if state.get('mode') == 'instruct' else state.get('chat_template_str', '')
        _tool_parsers, _, _ = detect_tool_call_format(_template_str)
    else:
        tool_func_names = None

    visible_prefix = []  # Accumulated tool call summaries + results
    last_save_time = time.monotonic()
    save_interval = 8  # seconds between periodic saves while streaming
    _tool_turn = 0
    while True:
        history = state['history']

        # Turn 0: use original flags; turns 2+: regenerate into the same entry.
        # _tool_turn tells chatbot_wrapper to skip version creation/sync so
        # that intermediate tool-loop regenerations don't pollute swipe history.
        if _tool_turn > 0:
            state['_tool_turn'] = True
            state['_skip_output_extensions'] = True

        regen = regenerate if _tool_turn == 0 else True
        cont = _continue if _tool_turn == 0 else False
        cur_text = text if _tool_turn == 0 else ''

        for i, history in enumerate(generate_chat_reply(cur_text, state, regen, cont, loading_message=True, for_ui=True)):
            # Prepend accumulated tool output to visible reply for display.
            # Save and restore the original to prevent the markers from leaking
            # back into chatbot_wrapper's shared output object, which would cause
            # duplication on the next yield.
            _original_visible = history['visible'][-1][1] if visible_prefix else None
            if visible_prefix:
                history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_original_visible])

            yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history

            if visible_prefix:
                history['visible'][-1][1] = _original_visible

            if i == 0:
                # Save old tool_sequence into version 0 (created by chatbot_wrapper
                # on the first yield). Only needed on the first regeneration when
                # versions didn't previously exist.
                if _old_tool_sequence is not None and _tool_turn == 0:
                    _ri = len(history['internal']) - 1
                    _versions = history.get('metadata', {}).get(f'assistant_{_ri}', {}).get('versions', [])
                    if _versions and 'tool_sequence' not in _versions[0]:
                        _versions[0]['tool_sequence'] = _old_tool_sequence

                    _old_tool_sequence = None

                time.sleep(0.125)

            current_time = time.monotonic()
            if i == 0 or (current_time - last_save_time) >= save_interval:
                save_history(history, state['unique_id'], state['character_menu'], state['mode'])
                last_save_time = current_time

            # Early stop on tool call detection
            if tool_func_names and parse_tool_call(history['internal'][-1][1], tool_func_names, parsers=_tool_parsers):
                break

        # Save the model's visible output before re-applying visible_prefix,
        # so we can extract thinking content from just this turn's output.
        _model_visible = history['visible'][-1][1]

        # Recover visible_prefix from existing visible text (e.g. on Continue
        # after a previous session had tool calls). Extract all tool-call
        # blocks and any text between them (thinking blocks, intermediate text).
        if tool_func_names and not visible_prefix and _model_visible:
            # NOTE(review): `.*?` matches the empty string at every position,
            # so tc_matches is never empty and the last match ends at the end
            # of the text -- the whole visible text becomes the prefix. The
            # block delimiters appear to be missing from this pattern; confirm.
            tc_matches = list(re.finditer(r'.*?', _model_visible, re.DOTALL))
            if tc_matches:
                prefix_end = tc_matches[-1].end()
                prefix = _model_visible[:prefix_end].strip()
                if prefix:
                    visible_prefix = [prefix]
                    _model_visible = _model_visible[prefix_end:].strip()

        # Re-apply visible prefix to the final state after streaming completes.
        # This is safe because we're no longer sharing the object with chatbot_wrapper.
        if visible_prefix:
            history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])

        if tool_func_names:
            save_history(history, state['unique_id'], state['character_menu'], state['mode'])

        # Check for tool calls
        if not tool_func_names or shared.stop_everything:
            break

        answer = history['internal'][-1][1]
        parsed_calls, content_prefix = parse_tool_call(answer, tool_func_names, return_prefix=True, parsers=_tool_parsers) if answer else (None, '')

        if not parsed_calls:
            break  # No tool calls — done

        # --- Process tool calls ---
        row_idx = len(history['internal']) - 1
        meta = history.get('metadata', {})
        seq = meta.setdefault(f'assistant_{row_idx}', {}).setdefault('tool_sequence', [])

        def _render():
            return chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])

        # Serialize tool calls and build display headers in one pass
        serialized = []
        tc_headers = []
        for tc in parsed_calls:
            tc['id'] = get_tool_call_id()
            fn_name = tc['function']['name']
            fn_args = tc['function'].get('arguments', {})

            serialized.append({
                'id': tc['id'],
                'type': 'function',
                'function': {
                    'name': fn_name,
                    'arguments': json.dumps(fn_args) if isinstance(fn_args, dict) else fn_args
                }
            })

            # Header shown in the UI, e.g. "name(k=v, ...)"
            if isinstance(fn_args, dict) and fn_args:
                args_summary = ', '.join(f'{k}={json.dumps(v, ensure_ascii=False)}' for k, v in fn_args.items())
            elif isinstance(fn_args, dict):
                args_summary = ''
            else:
                args_summary = str(fn_args)

            tc_headers.append(f'{fn_name}({args_summary})')

        seq_entry = {'tool_calls': serialized}
        if content_prefix.strip():
            # Strip GPT-OSS channel tokens so they don't get double-wrapped
            # by the template (which adds its own channel markup).
            clean = content_prefix.strip()
            if '<|channel|>' in clean and '<|message|>' in clean:
                inner = clean.split('<|message|>', 1)[1]
                if '<|end|>' in inner:
                    inner = inner.split('<|end|>', 1)[0]

                clean = inner.strip()

            if clean:
                seq_entry['content'] = clean

        seq.append(seq_entry)

        # Clear internal (raw tool markup)
        history['internal'][-1][1] = ''

        # Preserve thinking block and intermediate text from this turn.
        # content_prefix is the raw text before tool call syntax (returned
        # by parse_tool_call); HTML-escape it and extract thinking to get
        # the content the user should see.
        content_text = html.escape(content_prefix)
        thinking_content, intermediate = extract_thinking_block(content_text)
        if thinking_content:
            visible_prefix.append(f'<think>\n{thinking_content}\n</think>')

        if intermediate and intermediate.strip():
            visible_prefix.append(intermediate.strip())

        # Show placeholder accordions with "..." before execution starts
        # (tool calls may be slow, e.g. web search).
        # NOTE(review): these placeholder strings carry no wrapper markup around
        # the header/body; confirm whether tag delimiters were lost here.
        pending_placeholders = [f'{h}\n...\n' for h in tc_headers]
        history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
        yield _render(), history

        # Execute tools, store results, and replace placeholders with real results
        for i, tc in enumerate(parsed_calls):
            # Check for stop request before each tool execution
            if shared.stop_everything:
                # Record a cancellation result for this and every remaining call
                for j in range(i, len(parsed_calls)):
                    seq.append({'role': 'tool', 'content': 'Tool execution was cancelled by the user.', 'tool_call_id': parsed_calls[j]['id']})
                    pending_placeholders[j] = f'{tc_headers[j]}\nCancelled\n'

                history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
                yield _render(), history
                break

            fn_name = tc['function']['name']
            fn_args = tc['function'].get('arguments', {})
            result = execute_tool(fn_name, fn_args, tool_executors)

            seq.append({'role': 'tool', 'content': result, 'tool_call_id': tc['id']})
            # Pretty-print JSON results; show anything else as-is
            try:
                pretty_result = json.dumps(json.loads(result), indent=2, ensure_ascii=False)
            except (json.JSONDecodeError, TypeError):
                pretty_result = result

            # Replace the placeholder with the real result
            pending_placeholders[i] = f'{tc_headers[i]}\n{pretty_result}\n'
            history['visible'][-1][1] = '\n\n'.join(visible_prefix + pending_placeholders)
            yield _render(), history

        # Move completed tool calls into visible_prefix for next turns
        visible_prefix.extend(pending_placeholders)
        history['visible'][-1][1] = '\n\n'.join(visible_prefix)
        save_history(history, state['unique_id'], state['character_menu'], state['mode'])

        state['history'] = history
        _tool_turn += 1

    state.pop('_tool_turn', None)

    # If output extensions were deferred during tool turns, apply them now
    # to the final model response only (not to tool call markers).
    if state.pop('_skip_output_extensions', None):
        _model_visible = apply_extensions('output', _model_visible, state, is_chat=True)
        if visible_prefix:
            history['visible'][-1][1] = '\n\n'.join(visible_prefix + [_model_visible])
        else:
            history['visible'][-1][1] = _model_visible

        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history

    state['history'] = history

    # Sync version metadata so swipes show the full visible (with tool prefix)
    if visible_prefix and history.get('metadata'):
        row_idx = len(history['internal']) - 1
        key = f"assistant_{row_idx}"
        meta_entry = history['metadata'].get(key, {})
        if 'versions' in meta_entry and 'current_version_index' in meta_entry:
            current_idx = meta_entry['current_version_index']
            if current_idx < len(meta_entry['versions']):
                version_update = {
                    'content': history['internal'][row_idx][1],
                    'visible_content': history['visible'][row_idx][1]
                }
                ts = meta_entry.get('tool_sequence')
                if ts is not None:
                    version_update['tool_sequence'] = ts

                meta_entry['versions'][current_idx].update(version_update)

    save_history(history, state['unique_id'], state['character_menu'], state['mode'])


def remove_last_message(history):
    """
    Pop the last row (and its user_/assistant_ metadata) from `history`,
    unless it is the greeting row marked '<|BEGIN-VISIBLE-CHAT|>'.

    Returns (HTML-unescaped visible user text of the removed row, or '' if
    nothing was removed, history).
    """
    if 'metadata' not in history:
        history['metadata'] = {}

    if len(history['visible']) > 0 and history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
        row_idx = len(history['internal']) - 1
        last = history['visible'].pop()
        history['internal'].pop()

        # Remove metadata directly by known keys
        if f"user_{row_idx}" in history['metadata']:
            del history['metadata'][f"user_{row_idx}"]
        if f"assistant_{row_idx}" in history['metadata']:
            del history['metadata'][f"assistant_{row_idx}"]
    else:
        last = ['', '']

    return html.unescape(last[0]), history


def send_dummy_message(text, state): history = state['history'] # Handle both dict and string inputs if isinstance(text, dict): text = text['text'] # Initialize metadata if not present if 'metadata' not in history: history['metadata'] = {} row_idx =
len(history['internal']) history['visible'].append([html.escape(text), '']) history['internal'].append([apply_extensions('input', text, state, is_chat=True), '']) update_message_metadata(history['metadata'], "user", row_idx, timestamp=get_current_timestamp()) return history def send_dummy_reply(text, state): history = state['history'] # Handle both dict and string inputs if isinstance(text, dict): text = text['text'] # Initialize metadata if not present if 'metadata' not in history: history['metadata'] = {} if len(history['visible']) > 0 and not history['visible'][-1][1] == '': row_idx = len(history['internal']) history['visible'].append(['', '']) history['internal'].append(['', '']) # We don't need to add system metadata row_idx = len(history['internal']) - 1 history['visible'][-1][1] = html.escape(text) history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True) update_message_metadata(history['metadata'], "assistant", row_idx, timestamp=get_current_timestamp()) return history def redraw_html(history, name1, name2, mode, style, character, reset_cache=False): return chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=reset_cache) def start_new_chat(state, unique_id=None): mode = state['mode'] # Initialize with empty metadata dictionary history = {'internal': [], 'visible': [], 'metadata': {}} if mode != 'instruct': greeting = replace_character_names(state['greeting'], state['name1'], state['name2']) if greeting != '': history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] history['visible'] += [['', apply_extensions('output', html.escape(greeting), state, is_chat=True)]] # Add timestamp for assistant's greeting update_message_metadata(history['metadata'], "assistant", 0, timestamp=get_current_timestamp()) if unique_id is None: unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') save_history(history, unique_id, state['character_menu'], state['mode']) return history def get_history_file_path(unique_id, 
character, mode): if mode == 'instruct': p = shared.user_data_dir / 'logs' / 'instruct' / f'{unique_id}.json' else: p = shared.user_data_dir / 'logs' / 'chat' / character / f'{unique_id}.json' return p def save_history(history, unique_id, character, mode): if shared.args.multi_user: return if unique_id and unique_id.startswith('incognito-'): return p = get_history_file_path(unique_id, character, mode) if not p.parent.is_dir(): p.parent.mkdir(parents=True) with _history_file_lock: with open(p, 'w', encoding='utf-8') as f: f.write(json.dumps(history, indent=4, ensure_ascii=False)) def rename_history(old_id, new_id, character, mode): if shared.args.multi_user: return old_p = get_history_file_path(old_id, character, mode) new_p = get_history_file_path(new_id, character, mode) if new_p.parent != old_p.parent: logger.error(f"The following path is not allowed: \"{new_p}\".") elif new_p == old_p: logger.info("The provided path is identical to the old one.") elif new_p.exists(): logger.error(f"The new path already exists and will not be overwritten: \"{new_p}\".") else: logger.info(f"Renaming \"{old_p}\" to \"{new_p}\"") old_p.rename(new_p) def get_paths(state): if state['mode'] == 'instruct': return (shared.user_data_dir / 'logs' / 'instruct').glob('*.json') else: character = state['character_menu'] # Handle obsolete filenames and paths old_p = shared.user_data_dir / 'logs' / f'{character}_persistent.json' new_p = shared.user_data_dir / 'logs' / f'persistent_{character}.json' if old_p.exists(): logger.warning(f"Renaming \"{old_p}\" to \"{new_p}\"") old_p.rename(new_p) if new_p.exists(): unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S') p = get_history_file_path(unique_id, character, state['mode']) logger.warning(f"Moving \"{new_p}\" to \"{p}\"") p.parent.mkdir(exist_ok=True) new_p.rename(p) return (shared.user_data_dir / 'logs' / 'chat' / character).glob('*.json') def find_all_histories(state): if shared.args.multi_user: return [''] paths = get_paths(state) histories 
= sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True) return [path.stem for path in histories] def find_all_histories_with_first_prompts(state): if shared.args.multi_user: return [] paths = get_paths(state) histories = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True) result = [] for i, path in enumerate(histories): filename = path.stem file_content = "" with open(path, 'r', encoding='utf-8') as f: file_content = f.read() if state['search_chat'] and state['search_chat'] not in file_content: continue data = json.loads(file_content) if re.match(r'^[0-9]{8}-[0-9]{2}-[0-9]{2}-[0-9]{2}$', filename): first_prompt = "" if data and 'visible' in data and len(data['visible']) > 0: if len(data['internal']) > 0 and data['internal'][0][0] == '<|BEGIN-VISIBLE-CHAT|>': if len(data['visible']) > 1: first_prompt = html.unescape(data['visible'][1][0]) elif i == 0: first_prompt = "New chat" else: first_prompt = html.unescape(data['visible'][0][0]) elif i == 0: first_prompt = "New chat" else: first_prompt = filename first_prompt = first_prompt.strip() # Truncate the first prompt if it's longer than 30 characters if len(first_prompt) > 30: first_prompt = first_prompt[:30 - 3] + '...' result.append((first_prompt, filename)) return result def load_latest_history(state): ''' Loads the latest history for the given character in chat or chat-instruct mode, or the latest instruct history for instruct mode. 
''' if shared.args.multi_user: return start_new_chat(state), None histories = find_all_histories(state) if len(histories) > 0: # Try to load the last visited chat for this character/mode chat_state = load_last_chat_state() key = get_chat_state_key(state['character_menu'], state['mode']) last_chat_id = chat_state.get("last_chats", {}).get(key) # If we have a stored last chat and it still exists, use it if last_chat_id and last_chat_id in histories: unique_id = last_chat_id else: # Fall back to most recent (current behavior) unique_id = histories[0] history = load_history(unique_id, state['character_menu'], state['mode']) return history, unique_id else: return start_new_chat(state), None def load_history_after_deletion(state, idx): ''' Loads the latest history for the given character in chat or chat-instruct mode, or the latest instruct history for instruct mode. ''' import gradio as gr if shared.args.multi_user: return start_new_chat(state) histories = find_all_histories_with_first_prompts(state) idx = min(int(idx), len(histories) - 1) idx = max(0, idx) if len(histories) > 0: history = load_history(histories[idx][1], state['character_menu'], state['mode']) else: history = start_new_chat(state) histories = find_all_histories_with_first_prompts(state) return history, gr.update(choices=histories, value=histories[idx][1]) def update_character_menu_after_deletion(idx): import gradio as gr characters = utils.get_available_characters() idx = min(int(idx), len(characters) - 1) idx = max(0, idx) return gr.update(choices=characters, value=characters[idx]) def get_chat_state_key(character, mode): """Generate a key for storing last chat state""" if mode == 'instruct': return 'instruct' else: return f"chat_{character}" def load_last_chat_state(): """Load the last chat state from file""" state_file = shared.user_data_dir / 'logs' / 'chat_state.json' if state_file.exists(): try: with open(state_file, 'r', encoding='utf-8') as f: return json.loads(f.read()) except Exception: pass 
return {"last_chats": {}} def save_last_chat_state(character, mode, unique_id): """Save the last visited chat for a character/mode""" if shared.args.multi_user: return if unique_id and unique_id.startswith('incognito-'): return state = load_last_chat_state() key = get_chat_state_key(character, mode) state["last_chats"][key] = unique_id state_file = shared.user_data_dir / 'logs' / 'chat_state.json' state_file.parent.mkdir(exist_ok=True) with open(state_file, 'w', encoding='utf-8') as f: f.write(json.dumps(state, indent=2)) def load_history(unique_id, character, mode): p = get_history_file_path(unique_id, character, mode) if not p.exists(): return {'internal': [], 'visible': [], 'metadata': {}} f = json.loads(open(p, 'rb').read()) if 'internal' in f and 'visible' in f: history = f else: history = { 'internal': f['data'], 'visible': f['data_visible'] } # Add metadata if it doesn't exist if 'metadata' not in history: history['metadata'] = {} # Add placeholder timestamps for existing messages for i, (user_msg, asst_msg) in enumerate(history['internal']): if user_msg and user_msg != '<|BEGIN-VISIBLE-CHAT|>': update_message_metadata(history['metadata'], "user", i, timestamp="") if asst_msg: update_message_metadata(history['metadata'], "assistant", i, timestamp="") return history def load_history_json(file, history): try: file = file.decode('utf-8') f = json.loads(file) if 'internal' in f and 'visible' in f: history = f else: history = { 'internal': f['data'], 'visible': f['data_visible'] } # Add metadata if it doesn't exist if 'metadata' not in history: history['metadata'] = {} # Add placeholder timestamps for i, (user_msg, asst_msg) in enumerate(history['internal']): if user_msg and user_msg != '<|BEGIN-VISIBLE-CHAT|>': update_message_metadata(history['metadata'], "user", i, timestamp="") if asst_msg: update_message_metadata(history['metadata'], "assistant", i, timestamp="") return history except Exception: return history def delete_history(unique_id, character, mode): p 
= get_history_file_path(unique_id, character, mode) delete_file(p) def replace_character_names(text, name1, name2): text = text.replace('{{user}}', name1).replace('{{char}}', name2) return text.replace('', name1).replace('', name2) def generate_pfp_cache(character): cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() for path in [shared.user_data_dir / 'characters' / f"{character}.{extension}" for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): original_img = Image.open(path) # Define file paths pfp_path = Path(f'{cache_folder}/pfp_character.png') thumb_path = Path(f'{cache_folder}/pfp_character_thumb.png') # Save main picture and thumbnail original_img.save(pfp_path, format='PNG') thumb = make_thumbnail(original_img) thumb.save(thumb_path, format='PNG') # Return the path to the thumbnail, not the in-memory PIL Image object. return str(thumb_path) return None def load_character(character, name1, name2): context = greeting = "" greeting_field = 'greeting' picture = None filepath = None for extension in ["yml", "yaml", "json"]: filepath = shared.user_data_dir / 'characters' / f'{character}.{extension}' if filepath.exists(): break if filepath is None or not filepath.exists(): logger.error(f"Could not find the character \"{character}\" inside {shared.user_data_dir}/characters. 
No character has been loaded.") raise ValueError file_contents = open(filepath, 'r', encoding='utf-8').read() data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) cache_folder = Path(shared.args.disk_cache_dir) for path in [Path(f"{cache_folder}/pfp_character.png"), Path(f"{cache_folder}/pfp_character_thumb.png")]: if path.exists(): path.unlink() picture = generate_pfp_cache(character) # Finding the bot's name for k in ['name', 'bot', '<|bot|>', 'char_name']: if k in data and data[k] != '': name2 = data[k] break # Find the user name (if any) for k in ['your_name', 'user', '<|user|>']: if k in data and data[k] != '': name1 = data[k] break if 'context' in data: context = data['context'].strip() elif "char_persona" in data: context = build_pygmalion_style_context(data) greeting_field = 'char_greeting' greeting = data.get(greeting_field, greeting) return name1, name2, picture, greeting, context def restore_character_for_ui(state): """Reset character fields to the currently loaded character's saved values""" if state['character_menu'] and state['character_menu'] != 'None': try: name1, name2, picture, greeting, context = load_character(state['character_menu'], state['name1'], state['name2']) state['name2'] = name2 state['greeting'] = greeting state['context'] = context state['character_picture'] = picture # This triggers cache update via generate_pfp_cache return state, name2, context, greeting, picture except Exception as e: logger.error(f"Failed to reset character '{state['character_menu']}': {e}") return clear_character_for_ui(state) else: return clear_character_for_ui(state) def clear_character_for_ui(state): """Clear all character fields and picture cache""" state['name2'] = shared.settings['name2'] state['context'] = shared.settings['context'] state['greeting'] = shared.settings['greeting'] state['character_picture'] = None # Clear the cache files cache_folder = Path(shared.args.disk_cache_dir) for cache_file in 
['pfp_character.png', 'pfp_character_thumb.png']: cache_path = Path(f'{cache_folder}/{cache_file}') if cache_path.exists(): cache_path.unlink() return state, state['name2'], state['context'], state['greeting'], None @functools.cache def load_character_memoized(character, name1, name2): return load_character(character, name1, name2) @functools.cache def load_instruction_template_memoized(template): from modules.models_settings import load_instruction_template return load_instruction_template(template) def upload_character(file, img_path, tavern=False): import gradio as gr img = open_image_safely(img_path) decoded_file = file if isinstance(file, str) else file.decode('utf-8') try: data = json.loads(decoded_file) except Exception: data = yaml.safe_load(decoded_file) if 'char_name' in data: name = sanitize_filename(data['char_name']) greeting = data['char_greeting'] context = build_pygmalion_style_context(data) yaml_data = generate_character_yaml(name, greeting, context) else: name = sanitize_filename(data['name']) yaml_data = generate_character_yaml(data['name'], data['greeting'], data['context']) outfile_name = name i = 1 while (shared.user_data_dir / 'characters' / f'{outfile_name}.yaml').exists(): outfile_name = f'{name}_{i:03d}' i += 1 with open(shared.user_data_dir / 'characters' / f'{outfile_name}.yaml', 'w', encoding='utf-8') as f: f.write(yaml_data) if img is not None: img.save(shared.user_data_dir / 'characters' / f'{outfile_name}.png') logger.info(f'New character saved to "{shared.user_data_dir}/characters/{outfile_name}.yaml".') return gr.update(value=outfile_name, choices=get_available_characters()) def build_pygmalion_style_context(data): context = "" if 'char_persona' in data and data['char_persona'] != '': context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" if 'world_scenario' in data and data['world_scenario'] != '': context += f"Scenario: {data['world_scenario']}\n" if 'example_dialogue' in data and data['example_dialogue'] != '': 
context += f"{data['example_dialogue'].strip()}\n" context = f"{context.strip()}\n" return context def upload_tavern_character(img_path, _json): _json = {'char_name': _json['name'], 'char_persona': _json['description'], 'char_greeting': _json['first_mes'], 'example_dialogue': _json['mes_example'], 'world_scenario': _json['scenario']} return upload_character(json.dumps(_json), img_path, tavern=True) def check_tavern_character(img_path): import gradio as gr img = open_image_safely(img_path) if img is None: return "Invalid or disallowed image file.", None, None, gr.update(interactive=False) if "chara" not in img.info: return "Not a TavernAI card", None, None, gr.update(interactive=False) decoded_string = base64.b64decode(img.info['chara']).replace(b'\\r\\n', b'\\n') _json = json.loads(decoded_string) if "data" in _json: _json = _json["data"] return _json['name'], _json['description'], _json, gr.update(interactive=True) def upload_your_profile_picture(img_path): img = open_image_safely(img_path) cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() if img is None: if Path(f"{cache_folder}/pfp_me.png").exists(): Path(f"{cache_folder}/pfp_me.png").unlink() else: img = make_thumbnail(img) img.save(Path(f'{cache_folder}/pfp_me.png')) logger.info(f'Profile picture saved to "{cache_folder}/pfp_me.png"') def generate_character_yaml(name, greeting, context): data = { 'name': name, 'greeting': greeting, 'context': context, } data = {k: v for k, v in data.items() if v} # Strip falsy return yaml.dump(data, sort_keys=False, width=float("inf")) def generate_instruction_template_yaml(instruction_template): data = { 'instruction_template': instruction_template } return my_yaml_output(data) def save_character(name, greeting, context, picture, filename): filename = sanitize_filename(filename) if filename == "": logger.error("The filename is empty, so the character will not be saved.") return data = generate_character_yaml(name, greeting, 
context) filepath = shared.user_data_dir / 'characters' / f'{filename}.yaml' save_file(filepath, data) path_to_img = shared.user_data_dir / 'characters' / f'{filename}.png' if picture is not None: # Copy the image file from its source path to the character folder shutil.copy(picture, path_to_img) logger.info(f'Saved {path_to_img}.') def delete_character(name, instruct=False): name = sanitize_filename(name) # Check for character data files for extension in ["yml", "yaml", "json"]: delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}') # Check for character image files for extension in ["png", "jpg", "jpeg"]: delete_file(shared.user_data_dir / 'characters' / f'{name}.{extension}') def generate_user_pfp_cache(user): """Generate cached profile picture for user""" cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() for path in [shared.user_data_dir / 'users' / f"{user}.{extension}" for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): original_img = Image.open(path) # Define file paths pfp_path = Path(f'{cache_folder}/pfp_me.png') # Save thumbnail thumb = make_thumbnail(original_img) thumb.save(pfp_path, format='PNG') logger.info(f'User profile picture cached to "{pfp_path}"') return str(pfp_path) return None def load_user(user_name, name1, user_bio): """Load user profile from YAML file""" picture = None filepath = None for extension in ["yml", "yaml", "json"]: filepath = shared.user_data_dir / 'users' / f'{user_name}.{extension}' if filepath.exists(): break if filepath is None or not filepath.exists(): logger.error(f"Could not find the user \"{user_name}\" inside {shared.user_data_dir}/users. 
No user has been loaded.") raise ValueError with open(filepath, 'r', encoding='utf-8') as f: file_contents = f.read() extension = filepath.suffix[1:] # Remove the leading dot data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents) # Clear existing user picture cache cache_folder = Path(shared.args.disk_cache_dir) pfp_path = Path(f"{cache_folder}/pfp_me.png") if pfp_path.exists(): pfp_path.unlink() # Generate new picture cache picture = generate_user_pfp_cache(user_name) # Get user name if 'name' in data and data['name'] != '': name1 = data['name'] # Get user bio if 'user_bio' in data: user_bio = data['user_bio'] return name1, user_bio, picture def generate_user_yaml(name, user_bio): """Generate YAML content for user profile""" data = { 'name': name, 'user_bio': user_bio, } return yaml.dump(data, sort_keys=False, width=float("inf")) def save_user(name, user_bio, picture, filename): """Save user profile to YAML file""" filename = sanitize_filename(filename) if filename == "": logger.error("The filename is empty, so the user will not be saved.") return # Ensure the users directory exists users_dir = shared.user_data_dir / 'users' users_dir.mkdir(parents=True, exist_ok=True) data = generate_user_yaml(name, user_bio) filepath = shared.user_data_dir / 'users' / f'{filename}.yaml' save_file(filepath, data) path_to_img = shared.user_data_dir / 'users' / f'{filename}.png' if picture is not None: # Copy the image file from its source path to the users folder shutil.copy(picture, path_to_img) logger.info(f'Saved user profile picture to {path_to_img}.') def delete_user(name): """Delete user profile files""" name = sanitize_filename(name) # Check for user data files for extension in ["yml", "yaml", "json"]: delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}') # Check for user image files for extension in ["png", "jpg", "jpeg"]: delete_file(shared.user_data_dir / 'users' / f'{name}.{extension}') def 
update_user_menu_after_deletion(idx): """Update user menu after a user is deleted""" import gradio as gr users = get_available_users() if len(users) == 0: # Create a default user if none exist save_user('You', '', None, 'Default') users = get_available_users() idx = min(int(idx), len(users) - 1) idx = max(0, idx) return gr.update(choices=users, value=users[idx]) def handle_user_menu_change(state): """Handle user menu selection change""" try: name1, user_bio, picture = load_user(state['user_menu'], state['name1'], state['user_bio']) return [ name1, user_bio, picture ] except Exception as e: logger.error(f"Failed to load user '{state['user_menu']}': {e}") return [ state['name1'], state['user_bio'], None ] def handle_save_user_click(name1): """Handle save user button click""" import gradio as gr return [ name1, gr.update(visible=True) ] def my_yaml_output(data): ''' pyyaml is very inconsistent with multiline strings. for simple instruction template outputs, this is enough. ''' result = "" for k in data: result += k + ": |-\n" for line in data[k].splitlines(): result += " " + line.rstrip(' ') + "\n" return result def handle_send_dummy_message_click(text, state): history = send_dummy_message(text, state) save_history(history, state['unique_id'], state['character_menu'], state['mode']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html, {"text": "", "files": []}] def handle_send_dummy_reply_click(text, state): history = send_dummy_reply(text, state) save_history(history, state['unique_id'], state['character_menu'], state['mode']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html, {"text": "", "files": []}] def handle_remove_last_click(state): last_input, history = remove_last_message(state['history']) save_history(history, state['unique_id'], state['character_menu'], state['mode']) html = 
redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html, {"text": last_input, "files": []}] def handle_unique_id_select(state): history = load_history(state['unique_id'], state['character_menu'], state['mode']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) # Save this as the last visited chat save_last_chat_state(state['character_menu'], state['mode'], state['unique_id']) convert_to_markdown.cache_clear() return [history, html] def handle_start_new_chat_click(state): import gradio as gr history = start_new_chat(state) histories = find_all_histories_with_first_prompts(state) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() if len(histories) > 0: past_chats_update = gr.update(choices=histories, value=histories[0][1]) else: past_chats_update = gr.update(choices=histories) return [history, html, past_chats_update] def handle_start_incognito_chat_click(state): import gradio as gr unique_id = 'incognito-' + datetime.now().strftime('%Y%m%d-%H-%M-%S') history = start_new_chat(state, unique_id=unique_id) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() histories = find_all_histories_with_first_prompts(state) past_chats_update = gr.update(choices=histories, value=unique_id) return [history, html, past_chats_update] def handle_delete_chat_confirm_click(state): filtered_histories = find_all_histories_with_first_prompts(state) filtered_ids = [h[1] for h in filtered_histories] if state['unique_id'] not in filtered_ids: # Incognito or unknown chat — just load the most recent saved chat index = '0' else: index = str(filtered_ids.index(state['unique_id'])) delete_history(state['unique_id'], state['character_menu'], 
state['mode']) history, unique_id = load_history_after_deletion(state, index) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() return [history, html, unique_id] def handle_branch_chat_click(state): import gradio as gr branch_from_index = state['branch_index'] if branch_from_index == -1: history = state['history'] else: history = state['history'] history['visible'] = history['visible'][:branch_from_index + 1] history['internal'] = history['internal'][:branch_from_index + 1] # Prune the metadata dictionary to remove entries beyond the branch point if 'metadata' in history: history['metadata'] = {k: v for k, v in history['metadata'].items() if int(k.split('_')[-1]) <= branch_from_index} prefix = 'incognito-' if state['unique_id'] and state['unique_id'].startswith('incognito-') else '' new_unique_id = prefix + datetime.now().strftime('%Y%m%d-%H-%M-%S') save_history(history, new_unique_id, state['character_menu'], state['mode']) histories = find_all_histories_with_first_prompts(state) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() past_chats_update = gr.update(choices=histories, value=new_unique_id) return [history, html, past_chats_update, -1] def handle_edit_message_click(state): history = state['history'] message_index = int(state['edit_message_index']) new_text = state['edit_message_text'] role = state['edit_message_role'] # "user" or "assistant" if message_index >= len(history['internal']): html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html_output] role_idx = 0 if role == "user" else 1 if 'metadata' not in history: history['metadata'] = {} key = f"{role}_{message_index}" if key not in history['metadata']: history['metadata'][key] = {} # If no versions exist 
yet for this message, store the current (pre-edit) content as the first version. if "versions" not in history['metadata'][key] or not history['metadata'][key]["versions"]: original_content = history['internal'][message_index][role_idx] original_visible = history['visible'][message_index][role_idx] original_timestamp = history['metadata'][key].get('timestamp', get_current_timestamp()) version_entry = { "content": original_content, "visible_content": original_visible, "timestamp": original_timestamp } ts = history['metadata'][key].get('tool_sequence') if ts is not None: version_entry['tool_sequence'] = ts history['metadata'][key]["versions"] = [version_entry] history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True) history['visible'][message_index][role_idx] = html.escape(new_text) history['metadata'][key].pop('tool_sequence', None) add_message_version(history, role, message_index, is_current=True) save_history(history, state['unique_id'], state['character_menu'], state['mode']) html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html_output] def handle_navigate_version_click(state): history = state['history'] message_index = int(state['navigate_message_index']) direction = state['navigate_direction'] role = state['navigate_message_role'] if not role: logger.error("Role not provided for version navigation.") html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html] key = f"{role}_{message_index}" if 'metadata' not in history or key not in history['metadata'] or 'versions' not in history['metadata'][key]: html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html] metadata = history['metadata'][key] versions = metadata['versions'] # Default to the last version if 
current_version_index is not set current_idx = metadata.get('current_version_index', len(versions) - 1 if versions else 0) if direction == 'left': new_idx = max(0, current_idx - 1) else: # right new_idx = min(len(versions) - 1, current_idx + 1) if new_idx == current_idx: html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) return [history, html] msg_content_idx = 0 if role == 'user' else 1 # 0 for user content, 1 for assistant content in the pair version_to_load = versions[new_idx] history['internal'][message_index][msg_content_idx] = version_to_load['content'] history['visible'][message_index][msg_content_idx] = version_to_load['visible_content'] metadata['current_version_index'] = new_idx # Restore per-version tool_sequence so follow-up prompts see consistent context version_ts = version_to_load.get('tool_sequence') if version_ts is not None: metadata['tool_sequence'] = version_ts else: metadata.pop('tool_sequence', None) update_message_metadata(history['metadata'], role, message_index, timestamp=version_to_load['timestamp']) # Redraw and save html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) save_history(history, state['unique_id'], state['character_menu'], state['mode']) return [history, html] def handle_rename_chat_click(): import gradio as gr return [ gr.update(value="My New Chat"), gr.update(visible=True), ] def handle_rename_chat_confirm(rename_to, state): import gradio as gr if state['unique_id'] and state['unique_id'].startswith('incognito-'): return [ gr.update(), gr.update(visible=False), ] rename_history(state['unique_id'], rename_to, state['character_menu'], state['mode']) histories = find_all_histories_with_first_prompts(state) return [ gr.update(choices=histories, value=rename_to), gr.update(visible=False), ] def handle_search_chat_change(state): import gradio as gr histories = 
find_all_histories_with_first_prompts(state) return gr.update(choices=histories) def handle_upload_chat_history(load_chat_history, state): import gradio as gr history = start_new_chat(state) history = load_history_json(load_chat_history, history) save_history(history, state['unique_id'], state['character_menu'], state['mode']) histories = find_all_histories_with_first_prompts(state) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() if len(histories) > 0: past_chats_update = gr.update(choices=histories, value=histories[0][1]) else: past_chats_update = gr.update(choices=histories) return [ history, html, past_chats_update ] def handle_character_menu_change(state): import gradio as gr name1, name2, picture, greeting, context = load_character(state['character_menu'], state['name1'], state['name2']) state['name1'] = name1 state['name2'] = name2 state['character_picture'] = picture state['greeting'] = greeting state['context'] = context history, loaded_unique_id = load_latest_history(state) histories = find_all_histories_with_first_prompts(state) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() if len(histories) > 0: past_chats_update = gr.update(choices=histories, value=loaded_unique_id or histories[0][1]) else: past_chats_update = gr.update(choices=histories) return [ history, html, name1, name2, picture, greeting, context, past_chats_update ] def handle_character_picture_change(picture_path): """Update or clear cache when character picture changes""" picture = open_image_safely(picture_path) cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() if picture is not None: # Save to cache picture.save(Path(f'{cache_folder}/pfp_character.png'), format='PNG') thumb = make_thumbnail(picture) 
thumb.save(Path(f'{cache_folder}/pfp_character_thumb.png'), format='PNG') else: # Remove cache files when picture is cleared for cache_file in ['pfp_character.png', 'pfp_character_thumb.png']: cache_path = Path(f'{cache_folder}/{cache_file}') if cache_path.exists(): cache_path.unlink() def handle_mode_change(state): import gradio as gr history, loaded_unique_id = load_latest_history(state) histories = find_all_histories_with_first_prompts(state) # Ensure character picture cache exists if state['mode'] in ['chat', 'chat-instruct'] and state['character_menu'] and state['character_menu'] != 'None': generate_pfp_cache(state['character_menu']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) convert_to_markdown.cache_clear() if len(histories) > 0: past_chats_update = gr.update(choices=histories, value=loaded_unique_id or histories[0][1]) else: past_chats_update = gr.update(choices=histories) return [ history, html, gr.update(visible=state['mode'] != 'instruct'), gr.update(visible=state['mode'] == 'chat-instruct'), past_chats_update ] def handle_save_character_click(name2): import gradio as gr return [ name2, gr.update(visible=True) ] def handle_load_template_click(instruction_template): from modules.models_settings import load_instruction_template output = load_instruction_template(instruction_template) return [ output, "Select template to load..." 
] def handle_save_template_click(instruction_template_str): import gradio as gr contents = generate_instruction_template_yaml(instruction_template_str) root = str(shared.user_data_dir / 'instruction-templates') + '/' return [ "My Template.yaml", root, contents, root, gr.update(visible=True) ] def handle_delete_template_click(template): import gradio as gr root = str(shared.user_data_dir / 'instruction-templates') + '/' return [ f"{template}.yaml", root, root, gr.update(visible=False) ] def handle_your_picture_change(picture, state): upload_your_profile_picture(picture) html = redraw_html(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], reset_cache=True) return html def handle_send_instruction_click(state): import gradio as gr state['mode'] = 'instruct' state['history'] = {'internal': [], 'visible': [], 'metadata': {}} output = generate_chat_prompt("Input", state) if state["show_two_notebook_columns"]: return gr.update(), output, "" else: return output, gr.update(), gr.update() def handle_send_chat_click(state): import gradio as gr output = generate_chat_prompt("", state, _continue=True) if state["show_two_notebook_columns"]: return gr.update(), output, "" else: return output, gr.update(), gr.update() ================================================ FILE: modules/evaluate.py ================================================ import datetime from pathlib import Path import pandas as pd from tqdm import tqdm from modules import shared from modules.logging_colors import logger from modules.models import load_model, unload_model from modules.models_settings import get_model_metadata, update_model_parameters from modules.text_generation import encode def load_past_evaluations(): if (shared.user_data_dir / 'logs' / 'evaluations.csv').exists(): df = pd.read_csv(shared.user_data_dir / 'logs' / 'evaluations.csv', dtype=str) df['Perplexity'] = pd.to_numeric(df['Perplexity']) return df else: return 
pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment']) past_evaluations = load_past_evaluations() def save_past_evaluations(df): global past_evaluations past_evaluations = df filepath = shared.user_data_dir / 'logs' / 'evaluations.csv' filepath.parent.mkdir(parents=True, exist_ok=True) df.to_csv(filepath, index=False) def calculate_perplexity(models, input_dataset, stride, _max_length): ''' Based on: https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models ''' import torch from datasets import load_dataset from modules.torch_utils import clear_torch_cache if shared.args.loader == "llama.cpp": logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.") raise ValueError if not shared.args.no_use_fast: logger.warning("--no_use_fast is not set. If tokenizing the input dataset takes a long time, try reloading the model with that option set/checked.") global past_evaluations cumulative_log = '' cumulative_log += "Loading the input dataset...\n\n" yield cumulative_log # Copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/triton/utils/datautils.py if input_dataset == 'wikitext': data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') text = "\n\n".join(data['text']) elif input_dataset == 'ptb': data = load_dataset('ptb_text_only', 'penn_treebank', split='validation') text = "\n\n".join(data['sentence']) elif input_dataset == 'ptb_new': data = load_dataset('ptb_text_only', 'penn_treebank', split='test') text = " ".join(data['sentence']) else: with open(shared.user_data_dir / 'training' / 'datasets' / f'{input_dataset}.txt', 'r', encoding='utf-8') as f: text = f.read() for model in models: if is_in_past_evaluations(model, input_dataset, stride, _max_length): cumulative_log += f"`{model}` has already been tested. 
Ignoring.\n\n" yield cumulative_log continue if model != 'current model': try: yield cumulative_log + f"Loading `{model}`...\n\n" model_settings = get_model_metadata(model) shared.settings.update({k: v for k, v in model_settings.items() if k in shared.settings}) # hijacking the interface defaults update_model_parameters(model_settings) # hijacking the command-line arguments unload_model() shared.model, shared.tokenizer = load_model(model) except Exception: cumulative_log += f"Failed to load `{model}`. Moving on.\n\n" yield cumulative_log continue cumulative_log += f"Processing `{shared.model_name}`...\n\n" yield cumulative_log + "Tokenizing the input dataset...\n\n" encodings = encode(text, add_special_tokens=False) seq_len = encodings.shape[1] if _max_length: max_length = _max_length elif hasattr(shared.model.config, 'max_position_embeddings'): max_length = shared.model.config.max_position_embeddings else: max_length = 2048 nlls = [] prev_end_loc = 0 for begin_loc in tqdm(range(0, seq_len, stride)): yield cumulative_log + f"Evaluating... {100*begin_loc/seq_len:.2f}%" end_loc = min(begin_loc + max_length, seq_len) trg_len = end_loc - prev_end_loc # may be different from stride on last loop input_ids = encodings[:, begin_loc:end_loc] target_ids = input_ids.clone() target_ids[:, :-trg_len] = -100 clear_torch_cache() with torch.no_grad(): outputs = shared.model(input_ids=input_ids, labels=target_ids) # loss is calculated using CrossEntropyLoss which averages over valid labels # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels # to the left by 1. 
neg_log_likelihood = outputs.loss nlls.append(neg_log_likelihood) prev_end_loc = end_loc if end_loc == seq_len: break ppl = torch.exp(torch.stack(nlls).mean()) add_entry_to_past_evaluations(float(ppl), shared.model_name, input_dataset, stride, _max_length) save_past_evaluations(past_evaluations) message = f"The perplexity for `{shared.model_name}` is: {float(ppl)}" logger.info(message) cumulative_log += f"{message}\n\n" yield cumulative_log def add_entry_to_past_evaluations(perplexity, model, dataset, stride, max_length): global past_evaluations entry = { 'Model': model, 'LoRAs': ', '.join(shared.lora_names) or '-', 'Dataset': dataset, 'Perplexity': perplexity, 'stride': str(stride), 'max_length': str(max_length), 'Date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'Comment': '' } past_evaluations = pd.concat([past_evaluations, pd.DataFrame([entry])], ignore_index=True) def is_in_past_evaluations(model, dataset, stride, max_length): entries = past_evaluations[(past_evaluations['Model'] == model) & (past_evaluations['Dataset'] == dataset) & (past_evaluations['max_length'] == str(max_length)) & (past_evaluations['stride'] == str(stride))] if entries.shape[0] > 0: return True else: return False def generate_markdown_table(): sorted_df = past_evaluations.sort_values(by=['Dataset', 'stride', 'Perplexity', 'Date']) return sorted_df ================================================ FILE: modules/exllamav3.py ================================================ import math import queue import threading import traceback from pathlib import Path from typing import Any, List, Tuple import torch from exllamav3 import Cache, Config, Generator, Model, Tokenizer from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from exllamav3.generator import Job from exllamav3.generator.filter import Filter from exllamav3.generator.sampler import ( CustomSampler, SS_AdaptiveP, SS_Argmax, SS_MinP, SS_PresFreqP, SS_RepP, SS_Sample, SS_Temperature, SS_TopK, SS_TopP ) from modules 
import shared from modules.image_utils import ( convert_image_attachments_to_pil, convert_openai_messages_to_images ) from modules.logging_colors import logger from modules.text_generation import get_max_prompt_length try: import flash_attn except Exception: logger.warning('Failed to load flash-attention due to the following error:\n') traceback.print_exc() class LogitBiasFilter(Filter): """Filter subclass that applies a static additive logit bias mask.""" def __init__(self, tokenizer, logit_bias_dict): super().__init__(tokenizer=tokenizer, trigger_token=None, prefix_str=None, eos_after_completed=False) self.logit_bias_dict = logit_bias_dict self._mask = None def reset(self): pass def accept_token(self, token): pass def is_completed(self): return False def use_background_worker(self): return False def get_next_logit_mask(self): if self._mask is None: self._mask = torch.zeros((1, self.vocab_size), dtype=self.logits_dtype) for token_id_str, bias in self.logit_bias_dict.items(): token_id = int(token_id_str) if 0 <= token_id < self.vocab_size: self._mask[0, token_id] = bias return self._mask class ConcurrentGenerator: def __init__(self, generator): self.generator = generator self.lock = threading.Lock() self.job_queues = {} self.active = True self.has_jobs = threading.Event() self.thread = threading.Thread(target=self._iterate_loop, daemon=True) self.thread.start() def _iterate_loop(self): while self.active: self.has_jobs.wait(timeout=0.5) with self.lock: if not self.job_queues: self.has_jobs.clear() continue try: results = self.generator.iterate() except Exception: logger.error("Exception in ConcurrentGenerator iterate loop:\n" + traceback.format_exc()) for q in self.job_queues.values(): q.put(None) self.job_queues.clear() self.generator.clear_queue() self.has_jobs.clear() continue for result in results: job = result["job"] q = self.job_queues.get(job) if q: q.put(result) if result.get("eos"): self.job_queues.pop(job, None) if not self.job_queues: 
self.has_jobs.clear() def submit(self, job) -> queue.Queue: q = queue.Queue() with self.lock: self.job_queues[job] = q self.generator.enqueue(job) self.has_jobs.set() return q def cancel(self, job): with self.lock: if job in self.job_queues: self.generator.cancel(job) self.job_queues[job].put(None) del self.job_queues[job] def stop(self): self.active = False self.has_jobs.set() self.thread.join(timeout=5) class Exllamav3Model: def __init__(self): pass @property def device(self) -> torch.device: return torch.device(0) @classmethod def from_pretrained(cls, path_to_model): path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) # Reset global MMTokenAllocator to prevent token ID corruption when switching models from exllamav3.tokenizer.mm_embedding import ( FIRST_MM_EMBEDDING_INDEX, global_allocator ) global_allocator.next_token_index = FIRST_MM_EMBEDDING_INDEX config = Config.from_directory(str(path_to_model)) model = Model.from_config(config) # Calculate the closest multiple of 256 at or above the chosen value max_tokens = shared.args.ctx_size if max_tokens % 256 != 0: adjusted_tokens = ((max_tokens // 256) + 1) * 256 logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}") max_tokens = adjusted_tokens # Parse cache type cache_type = shared.args.cache_type.lower() cache_kwargs = {} if cache_type == 'fp16': layer_type = CacheLayer_fp16 elif cache_type.startswith('q'): layer_type = CacheLayer_quant if '_' in cache_type: # Different bits for k and v (e.g., q4_q8) k_part, v_part = cache_type.split('_') k_bits = int(k_part[1:]) v_bits = int(v_part[1:]) else: # Same bits for k and v (e.g., q4) k_bits = v_bits = int(cache_type[1:]) # Validate bit ranges if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8): logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. 
Falling back to fp16.") layer_type = CacheLayer_fp16 else: cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits} else: logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.") layer_type = CacheLayer_fp16 cache = Cache(model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) load_params = {'progressbar': True} split = None if shared.args.gpu_split: split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] load_params['use_per_device'] = split # Tensor-parallelism if shared.args.enable_tp: load_params['tensor_p'] = True load_params['tp_backend'] = shared.args.tp_backend # Load vision and draft before the main model so autosplit # accounts for their VRAM usage. # Load vision model component (ExLlamaV3 native) vision_model = None if "vision_config" in config.config_dict: logger.info("Vision component detected in model config. Attempting to load...") try: vision_model = Model.from_config(config, component="vision") vision_model.load(progressbar=True) logger.info("Vision model loaded successfully.") except Exception as e: logger.warning(f"Vision model loading failed (multimodal disabled): {e}") else: logger.info("No vision component in model config. 
Skipping multimodal setup.") # Initialize draft model for speculative decoding draft_model = None draft_cache = None if shared.args.model_draft and shared.args.model_draft.lower() not in ["", "none"]: logger.info(f"Loading draft model for speculative decoding: {shared.args.model_draft}") draft_path = Path(shared.args.model_draft) if not draft_path.is_dir(): draft_path = Path(f'{shared.args.model_dir}') / Path(shared.args.model_draft) if not draft_path.is_dir(): logger.warning(f"Draft model not found at {draft_path}, speculative decoding disabled.") else: draft_config = Config.from_directory(str(draft_path)) draft_model = Model.from_config(draft_config) draft_cache = Cache(draft_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) draft_load_params = {'progressbar': True} if split: draft_load_params['use_per_device'] = split draft_model.load(**draft_load_params) logger.info(f"Draft model loaded successfully. Max speculative tokens: {shared.args.draft_max}") # Load main model last model.load(**load_params) tokenizer = Tokenizer.from_config(config) generator = Generator( model=model, cache=cache, tokenizer=tokenizer, draft_model=draft_model, draft_cache=draft_cache, num_draft_tokens=shared.args.draft_max if draft_model is not None else 0, ) result = cls() result.model = model result.cache = cache result.tokenizer = tokenizer result.generator = generator result.parallel_generator = ConcurrentGenerator(generator) result.config = config result.max_tokens = max_tokens result.vision_model = vision_model result.draft_model = draft_model result.draft_cache = draft_cache return result, result def is_multimodal(self) -> bool: """Check if this model supports multimodal input.""" return hasattr(self, 'vision_model') and self.vision_model is not None def _process_images_for_generation(self, prompt: str, state: dict) -> Tuple[str, List[Any]]: """ Process all possible image inputs and return modified prompt + embeddings. 
Returns: (processed_prompt, image_embeddings) """ # Collect images from various sources using shared utilities pil_images = [] # From webui image_attachments (preferred format) if 'image_attachments' in state and state['image_attachments']: pil_images.extend(convert_image_attachments_to_pil(state['image_attachments'])) # From OpenAI API raw_images elif 'raw_images' in state and state['raw_images']: pil_images.extend(state['raw_images']) # From OpenAI API messages format elif 'messages' in state and state['messages']: pil_images.extend(convert_openai_messages_to_images(state['messages'])) if not pil_images: return prompt, [] # ExLlamaV3-specific: Generate embeddings try: # Use pre-computed embeddings if available (proper MMEmbedding lifetime) if 'image_embeddings' in state and state['image_embeddings']: # Use existing embeddings - this preserves MMEmbedding lifetime image_embeddings = state['image_embeddings'] else: # Do not reset the cache/allocator index; it causes token ID conflicts during generation. 
logger.info(f"Processing {len(pil_images)} image(s) with ExLlamaV3 vision model") image_embeddings = [ self.vision_model.get_image_embeddings(tokenizer=self.tokenizer, image=img) for img in pil_images ] # ExLlamaV3-specific: Handle prompt processing with placeholders placeholders = [ie.text_alias for ie in image_embeddings] if '<__media__>' in prompt: # Web chat: Replace <__media__> placeholders for alias in placeholders: prompt = prompt.replace('<__media__>', alias, 1) logger.info(f"Replaced {len(placeholders)} <__media__> placeholder(s)") else: # API: Prepend embedding aliases combined_placeholders = "\n".join(placeholders) prompt = combined_placeholders + "\n" + prompt logger.info(f"Prepended {len(placeholders)} embedding(s) to prompt") return prompt, image_embeddings except Exception as e: logger.error(f"Failed to process images: {e}") return prompt, [] def generate_with_streaming(self, prompt, state): """ Generate text with streaming using native ExLlamaV3 API """ if shared.is_multimodal: # Process images and modify prompt (ExLlamaV3-specific) prompt, image_embeddings = self._process_images_for_generation(prompt, state) else: image_embeddings = [] # Greedy decoding is a special case if state['temperature'] == 0: sampler = CustomSampler([SS_Argmax()]) else: # 1. 
Create a list of all active, unordered samplers unordered_samplers = [] # Penalties penalty_range = state['repetition_penalty_range'] if penalty_range <= 0: penalty_range = int(10e7) # Use large number for "full context" rep_decay = 0 # Not a configurable parameter # Add penalty samplers if they are active if state['repetition_penalty'] != 1.0: unordered_samplers.append(SS_RepP(state['repetition_penalty'], penalty_range, rep_decay)) if state['presence_penalty'] != 0.0 or state['frequency_penalty'] != 0.0: unordered_samplers.append(SS_PresFreqP(state['presence_penalty'], state['frequency_penalty'], penalty_range, rep_decay)) # Standard samplers if state['top_k'] > 0: unordered_samplers.append(SS_TopK(state['top_k'])) if state['top_p'] < 1.0: unordered_samplers.append(SS_TopP(state['top_p'])) if state['min_p'] > 0.0: unordered_samplers.append(SS_MinP(state['min_p'])) # Temperature (SS_NoOp is returned if temp is 1.0) unordered_samplers.append(SS_Temperature(state['temperature'])) # 2. Define the mapping from class names to the priority list keys class_name_to_nickname = { 'SS_RepP': 'repetition_penalty', 'SS_PresFreqP': 'presence_frequency_penalty', 'SS_TopK': 'top_k', 'SS_TopP': 'top_p', 'SS_MinP': 'min_p', 'SS_Temperature': 'temperature', } # 3. Get the priority list and handle temperature_last default_priority = ['repetition_penalty', 'presence_frequency_penalty', 'top_k', 'top_p', 'min_p', 'temperature'] sampler_priority = list(state.get('sampler_priority') or default_priority) if state['temperature_last'] and 'temperature' in sampler_priority: sampler_priority.append(sampler_priority.pop(sampler_priority.index('temperature'))) # The preset system uses separate 'presence_penalty' and # 'frequency_penalty', but ExLlamaV3 has a single combined # SS_PresFreqP sampler. Normalize to the combined name. sampler_priority = ['presence_frequency_penalty' if x in ('presence_penalty', 'frequency_penalty') else x for x in sampler_priority] # 4. 
Sort the unordered list based on the priority list def custom_sort_key(sampler_obj): class_name = sampler_obj.__class__.__name__ nickname = class_name_to_nickname.get(class_name) if nickname and nickname in sampler_priority: return sampler_priority.index(nickname) return -1 ordered_samplers = sorted(unordered_samplers, key=custom_sort_key) # 5. Add the final sampling stage and build the sampler if state.get('adaptive_target', 0) > 0: ordered_samplers.append(SS_AdaptiveP(state['adaptive_target'], state['adaptive_decay'])) else: ordered_samplers.append(SS_Sample()) sampler = CustomSampler(ordered_samplers) # Encode prompt with embeddings (ExLlamaV3-specific) input_ids = self.tokenizer.encode( prompt, add_bos=state['add_bos_token'], encode_special_tokens=True, embeddings=image_embeddings, ) input_ids = input_ids[:, -get_max_prompt_length(state):] self._last_prompt_token_count = input_ids.shape[-1] # Determine max_new_tokens if state['auto_max_new_tokens']: max_new_tokens = state['truncation_length'] - self._last_prompt_token_count else: max_new_tokens = state['max_new_tokens'] # Use full EOS token list from config (may contain multiple IDs) stop_conditions = [] if not state['ban_eos_token']: for eos_id in self.config.eos_token_id_list: if eos_id is not None: stop_conditions.append(eos_id) # Build filters for logit_bias (OpenAI API) filters = [] logit_bias = state.get('logit_bias') if logit_bias: filters.append(LogitBiasFilter(self.tokenizer, logit_bias)) # Logprobs support (OpenAI API) logprobs = state.get('logprobs', 0) or 0 return_top_tokens = logprobs if logprobs > 0 else 0 seed = state.get('seed', -1) job = Job( input_ids=input_ids, max_new_tokens=max_new_tokens, decode_special_tokens=not state['skip_special_tokens'], embeddings=image_embeddings if image_embeddings else None, sampler=sampler, seed=seed if seed >= 0 else None, stop_conditions=stop_conditions if stop_conditions else None, filters=filters if filters else None, return_top_tokens=return_top_tokens, 
return_probs=return_top_tokens > 0, ) # Stream generation response_text = "" stop_event = state.get('stop_event') self.last_completion_probabilities = [] result_queue = self.parallel_generator.submit(job) try: while True: if shared.stop_everything or (stop_event and stop_event.is_set()): break try: result = result_queue.get(timeout=0.1) except queue.Empty: continue if result is None or result.get("eos"): # Capture logprobs from the final eos result too if result is not None and return_top_tokens > 0: self._capture_logprobs(result) break chunk = result.get("text", "") # Capture logprobs from streaming results if return_top_tokens > 0: self._capture_logprobs(result) if chunk: response_text += chunk yield response_text finally: self.parallel_generator.cancel(job) def _capture_logprobs(self, result): """Convert ExLlamav3 top-k token data to the shared logprobs format.""" top_k_tokens = result.get("top_k_tokens") top_k_probs = result.get("top_k_probs") if top_k_tokens is None or top_k_probs is None: return id_to_piece = self.tokenizer.get_id_to_piece_list(True) # top_k_tokens shape: (batch, seq_len, k), top_k_probs same for seq_idx in range(top_k_tokens.shape[1]): entry = {"top_logprobs": []} for k_idx in range(top_k_tokens.shape[2]): token_id = top_k_tokens[0, seq_idx, k_idx].item() prob = top_k_probs[0, seq_idx, k_idx].item() token_str = id_to_piece[token_id] if token_id < len(id_to_piece) else f"<{token_id}>" logprob = math.log(prob) if prob > 0 else float("-inf") entry["top_logprobs"].append({"token": token_str, "logprob": logprob}) self.last_completion_probabilities.append(entry) def generate(self, prompt, state): output = "" for chunk in self.generate_with_streaming(prompt, state): output = chunk return output def get_logits(self, token_ids, **kwargs): """ Process a batch of token_ids and return the logits for the last token. This will reset and overwrite the model's cache. 
""" # Initialize a single params dictionary that will be updated in-place params = { "cache": self.cache, "reconstruct": False, "attn_mode": "flash_attn", "batch_shape": (1, self.max_tokens), "past_len": 0 } params.update(kwargs) # Process prefix tokens to fill the cache and generate recurrent state if token_ids.shape[-1] > 1: prefix_ids = token_ids[:, :-1] # This forward call updates the 'params' dict with the recurrent state self.model.forward( input_ids=prefix_ids, params=params ) # Update past_len for the next call params["past_len"] = prefix_ids.shape[-1] # Process the last token, now using the state-filled 'params' dict last_token_ids = token_ids[:, -1:] logits = self.model.forward( input_ids=last_token_ids, params=params ) return logits.float().cpu() def encode(self, string, **kwargs): add_bos = kwargs.pop('add_bos', True) return self.tokenizer.encode(string, add_bos=add_bos, **kwargs) def decode(self, ids, **kwargs): if isinstance(ids, torch.Tensor) and ids.dim() == 0: ids = ids.view(1) return self.tokenizer.decode(ids, **kwargs) @property def last_prompt_token_count(self): return getattr(self, '_last_prompt_token_count', 0) def unload(self): logger.info("Unloading ExLlamaV3 model components...") if hasattr(self, 'parallel_generator') and self.parallel_generator is not None: try: self.parallel_generator.stop() except Exception as e: logger.warning(f"Error stopping parallel generator: {e}") self.parallel_generator = None if hasattr(self, 'vision_model') and self.vision_model is not None: try: del self.vision_model except Exception as e: logger.warning(f"Error unloading vision model: {e}") self.vision_model = None if hasattr(self, 'draft_model') and self.draft_model is not None: try: self.draft_model.unload() del self.draft_model except Exception as e: logger.warning(f"Error unloading draft model: {e}") self.draft_model = None if hasattr(self, 'draft_cache') and self.draft_cache is not None: self.draft_cache = None if hasattr(self, 'model') and self.model is 
not None: try: self.model.unload() del self.model except Exception as e: logger.warning(f"Error unloading main model: {e}") self.model = None if hasattr(self, 'cache') and self.cache is not None: self.cache = None if hasattr(self, 'generator') and self.generator is not None: self.generator = None if hasattr(self, 'tokenizer') and self.tokenizer is not None: self.tokenizer = None ================================================ FILE: modules/exllamav3_hf.py ================================================ import os import traceback from pathlib import Path from typing import Any, Dict, Optional, Union import torch from torch.nn import CrossEntropyLoss from transformers import ( GenerationConfig, GenerationMixin, PretrainedConfig, PreTrainedModel ) from transformers.modeling_outputs import CausalLMOutputWithPast from exllamav3 import Cache, Config, Model from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant from modules import shared from modules.logging_colors import logger try: import flash_attn except Exception: logger.warning('Failed to load flash-attention due to the following error:\n') traceback.print_exc() class Exllamav3HF(PreTrainedModel, GenerationMixin): def __init__(self, model_dir): hf_config = PretrainedConfig.from_pretrained(model_dir) super().__init__(hf_config) exl3_config = Config.from_directory(model_dir) self.generation_config = GenerationConfig() self.ex_model = Model.from_config(exl3_config) # Calculate the closest multiple of 256 at or above the chosen value max_tokens = shared.args.ctx_size if max_tokens % 256 != 0: adjusted_tokens = ((max_tokens // 256) + 1) * 256 logger.warning(f"max_num_tokens must be a multiple of 256. 
Adjusting from {max_tokens} to {adjusted_tokens}") max_tokens = adjusted_tokens # Parse cache type cache_type = shared.args.cache_type.lower() cache_kwargs = {} if cache_type == 'fp16': layer_type = CacheLayer_fp16 elif cache_type.startswith('q'): layer_type = CacheLayer_quant if '_' in cache_type: # Different bits for k and v (e.g., q4_q8) k_part, v_part = cache_type.split('_') k_bits = int(k_part[1:]) v_bits = int(v_part[1:]) else: # Same bits for k and v (e.g., q4) k_bits = v_bits = int(cache_type[1:]) # Validate bit ranges if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8): logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.") layer_type = CacheLayer_fp16 else: cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits} else: logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.") layer_type = CacheLayer_fp16 self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) # Create load parameters dictionary load_params = {'progressbar': True} if shared.args.gpu_split: split = [float(alloc) for alloc in shared.args.gpu_split.split(",")] load_params['use_per_device'] = split # Tensor-parallelism if shared.args.enable_tp: load_params['tensor_p'] = True load_params['tp_backend'] = shared.args.tp_backend self.ex_model.load(**load_params) self.past_seq = None self.max_tokens = max_tokens self.layer_type = layer_type self.cache_kwargs = cache_kwargs if shared.args.cfg_cache: self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs) self.past_seq_negative = None def _validate_model_class(self): pass def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): pass def prepare_inputs_for_generation(self, input_ids, **kwargs): return {'input_ids': input_ids, **kwargs} @property def device(self) -> torch.device: return torch.device(0) def __call__(self, *args, **kwargs): use_cache = 
kwargs.get('use_cache', True) labels = kwargs.get('labels', None) past_key_values = kwargs.get('past_key_values', None) if len(args) > 0: if not shared.args.cfg_cache: logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.") return input_ids = args[0] is_negative = True past_seq = self.past_seq_negative ex_cache = self.ex_cache_negative else: input_ids = kwargs['input_ids'] is_negative = False past_seq = self.past_seq ex_cache = self.ex_cache seq = input_ids[0].tolist() if is_negative and past_key_values is not None: seq = past_key_values + seq seq_tensor = torch.tensor(seq) reset = True # Maximum number of tokens to process in a single forward pass max_chunk_size = 2048 # Make the forward call if labels is None: if past_seq is not None: min_length = min(past_seq.shape[0], seq_tensor.shape[0]) indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length])) if len(indices) > 0: longest_prefix = indices[0].item() else: longest_prefix = min_length if longest_prefix > 0: reset = False current_len = longest_prefix remaining_tokens = len(seq_tensor) - longest_prefix - 1 if remaining_tokens > 0: # Process tokens from longest_prefix to second-to-last token tokens_to_process = seq_tensor[longest_prefix:-1] # Use prefill() to fill the cache without computing logits for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] self.ex_model.prefill( input_ids=chunk.view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": longest_prefix + i, "batch_shape": (1, self.max_tokens), } ) current_len = longest_prefix + remaining_tokens if reset: if len(seq_tensor) > 1: # Process all tokens except the last one tokens_to_process = seq_tensor[:-1] # Use prefill() to fill the cache without computing logits current_len = 0 for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] self.ex_model.prefill( 
input_ids=chunk.view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": current_len, "batch_shape": (1, self.max_tokens), } ) current_len += chunk.shape[0] else: current_len = 0 # Process the last token and get logits logits = self.ex_model.forward( input_ids=seq_tensor[-1:].view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": current_len, "batch_shape": (1, self.max_tokens), } ).to(input_ids.device).float() else: # Labels path: use cache for cross-chunk attention. tokens_to_process = seq_tensor all_logits = None current_len = 0 for i in range(0, tokens_to_process.shape[0], max_chunk_size): chunk = tokens_to_process[i:i + max_chunk_size] chunk_logits = self.ex_model.forward( input_ids=chunk.view(1, -1), params={ "attn_mode": "flash_attn", "cache": ex_cache, "past_len": current_len, "batch_shape": (1, self.max_tokens), } ).float() current_len += chunk.shape[0] if all_logits is None: all_logits = chunk_logits else: all_logits = torch.cat([all_logits, chunk_logits], dim=1) logits = all_logits if is_negative: self.past_seq_negative = seq_tensor else: self.past_seq = seq_tensor if torch.cuda.is_available(): torch.cuda.synchronize() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, logits.shape[-1]) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss) @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" if isinstance(pretrained_model_name_or_path, str): 
pretrained_model_name_or_path = Path(pretrained_model_name_or_path) pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) return Exllamav3HF(pretrained_model_name_or_path) def unload(self): """Properly unload the ExllamaV3 model and free GPU memory.""" if hasattr(self, 'ex_model') and self.ex_model is not None: self.ex_model.unload() self.ex_model = None if hasattr(self, 'ex_cache') and self.ex_cache is not None: self.ex_cache = None # Clean up any additional ExllamaV3 resources if hasattr(self, 'past_seq'): self.past_seq = None if hasattr(self, 'past_seq_negative'): self.past_seq_negative = None if hasattr(self, 'ex_cache_negative'): self.ex_cache_negative = None ================================================ FILE: modules/extensions.py ================================================ import importlib import importlib.util import sys import traceback from functools import partial from inspect import signature from pathlib import Path import modules.shared as shared from modules.logging_colors import logger state = {} available_extensions = [] setup_called = set() def apply_settings(extension, name): if not hasattr(extension, 'params'): return for param in extension.params: _id = f"{name}-{param}" shared.default_settings[_id] = extension.params[param] if _id in shared.settings: extension.params[param] = shared.settings[_id] def load_extensions(): global state, setup_called state = {} for i, name in enumerate(shared.args.extensions): if name not in available_extensions: continue if name != 'api': logger.info(f'Loading the extension "{name}"') try: # Prefer user extension, fall back to system extension user_script_path = shared.user_data_dir / 'extensions' / name / 'script.py' if user_script_path.exists(): spec = importlib.util.spec_from_file_location( f"user_ext_{name}", str(user_script_path) ) extension = importlib.util.module_from_spec(spec) sys.modules[spec.name] = extension spec.loader.exec_module(extension) 
else: extension = importlib.import_module(f"extensions.{name}.script") if extension not in setup_called: apply_settings(extension, name) if hasattr(extension, "setup"): extension.setup() setup_called.add(extension) state[name] = [True, i, extension] # Store extension object except ModuleNotFoundError: extension_location = shared.user_data_dir / 'extensions' / name if user_script_path.exists() else Path('extensions') / name windows_path = str(extension_location).replace('/', '\\') logger.error( f"Could not import the requirements for '{name}'. Make sure to install the requirements for the extension.\n\n" f"* To install requirements automatically, launch the update_wizard script for your OS and:\n\n" f"1. Choose option B (Install/update extensions requirements)\n" f"2. Select '{name}' from the extension list\n\n" f"* To install requirements manually, launch the cmd script for your OS and paste the following command:\n\n" f"Linux / Mac:\n\n" f"pip install -r {extension_location}/requirements.txt --upgrade\n\n" f"Windows:\n\n" f"pip install -r {windows_path}\\requirements.txt --upgrade\n" ) raise except Exception: logger.error(f'Failed to load the extension "{name}".') traceback.print_exc() # This iterator returns the extensions in the order specified in the command-line def iterator(): for name in sorted(state, key=lambda x: state[x][1]): if state[name][0]: yield state[name][2], name # Use stored extension object # Extension functions that map string -> string def _apply_string_extensions(function_name, text, state, is_chat=False): for extension, _ in iterator(): if hasattr(extension, function_name): func = getattr(extension, function_name) # Handle old extensions without the 'state' arg or # the 'is_chat' kwarg count = 0 has_chat = False for k in signature(func).parameters: if k == 'is_chat': has_chat = True else: count += 1 if count == 2: args = [text, state] else: args = [text] if has_chat: kwargs = {'is_chat': is_chat} else: kwargs = {} text = func(*args, 
**kwargs) return text # Extension functions that map string -> string def _apply_chat_input_extensions(text, visible_text, state): for extension, _ in iterator(): if hasattr(extension, 'chat_input_modifier'): text, visible_text = extension.chat_input_modifier(text, visible_text, state) return text, visible_text # custom_generate_chat_prompt handling - currently only the first one will work def _apply_custom_generate_chat_prompt(text, state, **kwargs): for extension, _ in iterator(): if hasattr(extension, 'custom_generate_chat_prompt'): return extension.custom_generate_chat_prompt(text, state, **kwargs) return None # Extension that modifies the input parameters before they are used def _apply_state_modifier_extensions(state): for extension, _ in iterator(): if hasattr(extension, "state_modifier"): state = getattr(extension, "state_modifier")(state) return state # Extension that modifies the chat history before it is used def _apply_history_modifier_extensions(history): for extension, _ in iterator(): if hasattr(extension, "history_modifier"): history = getattr(extension, "history_modifier")(history) return history # Extension functions that override the default tokenizer output - The order of execution is not defined def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): for extension, _ in iterator(): if hasattr(extension, function_name): prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds) return prompt, input_ids, input_embeds # Allow extensions to add their own logits processors to the stack being run. # Each extension would call `processor_list.append({their LogitsProcessor}())`. 
def _apply_logits_processor_extensions(function_name, processor_list, input_ids): for extension, _ in iterator(): if hasattr(extension, function_name): result = getattr(extension, function_name)(processor_list, input_ids) if type(result) is list: processor_list = result return processor_list # Get prompt length in tokens after applying extension functions which override the default tokenizer output # currently only the first one will work def _apply_custom_tokenized_length(prompt): for extension, _ in iterator(): if hasattr(extension, 'custom_tokenized_length'): return getattr(extension, 'custom_tokenized_length')(prompt) return None # Custom generate reply handling - currently only the first one will work def _apply_custom_generate_reply(): for extension, _ in iterator(): if hasattr(extension, 'custom_generate_reply'): return getattr(extension, 'custom_generate_reply') return None def _apply_custom_css(): all_css = '' for extension, _ in iterator(): if hasattr(extension, 'custom_css'): all_css += getattr(extension, 'custom_css')() return all_css def _apply_custom_js(): all_js = '' for extension, _ in iterator(): if hasattr(extension, 'custom_js'): all_js += getattr(extension, 'custom_js')() return all_js def create_extensions_block(): import gradio as gr to_display = [] for extension, name in iterator(): if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)): to_display.append((extension, name)) # Creating the extension ui elements if len(to_display) > 0: with gr.Column(elem_id="extensions"): for row in to_display: extension, _ = row extension.ui() def create_extensions_tabs(): import gradio as gr for extension, name in iterator(): if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)): display_name = getattr(extension, 'params', {}).get('display_name', name) with gr.Tab(display_name, elem_classes="extension-tab"): extension.ui() EXTENSION_MAP = { "input": 
partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), "chat_input": _apply_chat_input_extensions, "state": _apply_state_modifier_extensions, "history": _apply_history_modifier_extensions, "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), 'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'), "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt, "custom_generate_reply": _apply_custom_generate_reply, "tokenized_length": _apply_custom_tokenized_length, "css": _apply_custom_css, "js": _apply_custom_js } def apply_extensions(typ, *args, **kwargs): if typ not in EXTENSION_MAP: raise ValueError(f"Invalid extension type {typ}") return EXTENSION_MAP[typ](*args, **kwargs) ================================================ FILE: modules/grammar/grammar_utils.py ================================================ ''' This file has been 100% copied from this PR to the Transformers library: https://github.com/huggingface/transformers/pull/27557 Author: Saibo-creator Author GitHub: https://github.com/Saibo-creator All credits go to the author. 
''' import logging import re import time from abc import ABC from functools import lru_cache from typing import Dict, List import torch from modules import shared logger = logging.getLogger(__name__) ######################## # EBNF Grammar Parsing # ######################## END_OF_ALTERNATE_MARKER = 0 END_OF_RULE_MARKER = 0 TO_BE_FILLED_MARKER = 0 REF_RULE_MARKER = 1 LITERAL_MARKER = 2 class ParseState: def __init__(self): self.symbol_ids = {} self.grammar_encoding = [] # old name: out_grammar def get_symbol_id(state, src): if src not in state.symbol_ids: state.symbol_ids[src] = len(state.symbol_ids) return state.symbol_ids[src] def generate_symbol_id(state, base_name): next_id = len(state.symbol_ids) state.symbol_ids[base_name + "_" + str(next_id)] = next_id return next_id def is_word_char(c): return c.isalnum() or c == "-" or c == "_" def hex_to_int(c): if c.isdigit(): return int(c) elif "a" <= c.lower() <= "f": return ord(c.lower()) - ord("a") + 10 raise RuntimeError("unknown hex char " + c) def remove_leading_white_space(src, newline_ok): """ Skips over whitespace and comments in the input string. This function processes the input string, skipping over any spaces, tabs, and content following a '#' character, which denotes a comment. The parsing of a comment continues until the end of the line (denoted by newline characters '\r' or '\n'). If the 'newline_ok' parameter is set to False, the function will stop processing and return the remaining string upon encountering a newline character, otherwise it will skip over newline characters as well. Parameters: src (str): The input string to be processed. newline_ok (bool): A flag indicating whether encountering a newline character should stop the parsing (False) or if it should be skipped (True). Returns: str: The remaining portion of the input string after skipping whitespace and comments. 
""" pos = 0 while pos < len(src) and (src[pos].isspace() or src[pos] == "#"): if src[pos] == "#": while pos < len(src) and src[pos] not in ("\r", "\n"): pos += 1 else: if not newline_ok and src[pos] in ("\r", "\n"): break pos += 1 return src[pos:] def parse_name(src): pos = 0 while pos < len(src) and is_word_char(src[pos]): pos += 1 if pos == 0: raise RuntimeError("expecting name at " + src) return src[:pos], src[pos:] def read_hex(s): val = 0 for c in s: val = (val << 4) + hex_to_int(c) return chr(val) def parse_char(src): """ parse the leading char from the input string :param src: :return: char, remaining_src """ # if we have a backslash, it's maybe an escape if src[0] == "\\": esc = src[1] if esc == "x": return read_hex(src[2:4]), src[4:] elif esc == "u": return read_hex(src[2:6]), src[6:] elif esc == "U": return read_hex(src[2:10]), src[10:] elif esc in ('"', "[", "]", "\\", "-"): return esc, src[2:] elif esc == "r": return "\r", src[2:] elif esc == "n": return "\n", src[2:] elif esc == "t": return "\t", src[2:] elif esc == "\\": return "\\", src[2:] raise RuntimeError("unknown escape at " + src) elif src: return src[0], src[1:] raise RuntimeError("unexpected end of input") def parse_sequence(state, src, rule_name, outbuf, is_nested): out_start_pos = len(outbuf) # sequence size, will be replaced at end when known outbuf.append(TO_BE_FILLED_MARKER) last_sym_start = len(outbuf) remaining_src = src while remaining_src: if remaining_src[0] == '"': # literal string remaining_src = remaining_src[1:] last_sym_start = len(outbuf) while remaining_src[0] != '"': char, remaining_src = parse_char(remaining_src) # each char of a literal is encoded as a "range" of char - char outbuf.append(LITERAL_MARKER) outbuf.append(ord(char)) outbuf.append(ord(char)) remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) elif remaining_src[0] == "[": # char range(s) remaining_src = remaining_src[1:] last_sym_start = len(outbuf) # num chars in range - replaced at end of 
loop outbuf.append(TO_BE_FILLED_MARKER) while remaining_src[0] != "]": char, remaining_src = parse_char(remaining_src) outbuf.append(ord(char)) if remaining_src[0] == "-" and remaining_src[1] != "]": endchar_pair, remaining_src = parse_char(remaining_src[1:]) outbuf.append(ord(endchar_pair)) else: # chars that aren't part of a c1-c2 range are just doubled (i.e., c-c) outbuf.append(ord(char)) # replace num chars with actual outbuf[last_sym_start] = len(outbuf) - last_sym_start - 1 remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) elif is_word_char(remaining_src[0]): # rule reference name, remaining_src = parse_name(remaining_src) ref_rule_id = get_symbol_id(state, name) remaining_src = remove_leading_white_space(remaining_src, is_nested) last_sym_start = len(outbuf) outbuf.append(REF_RULE_MARKER) outbuf.append(ref_rule_id) elif remaining_src[0] == "(": # grouping # parse nested alternates into synthesized rule remaining_src = remove_leading_white_space(remaining_src[1:], True) sub_rule_id = generate_symbol_id(state, rule_name) remaining_src = parse_alternates(state, remaining_src, rule_name, sub_rule_id, True) last_sym_start = len(outbuf) # output reference to synthesized rule outbuf.append(REF_RULE_MARKER) outbuf.append(sub_rule_id) if remaining_src[0] != ")": raise RuntimeError("expecting ')' at " + remaining_src) remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) elif remaining_src[0] in ("*", "+", "?"): # repetition operator if len(outbuf) - out_start_pos - 1 == 0: raise RuntimeError("expecting preceeding item to */+/? at " + remaining_src) out_grammar = state.grammar_encoding # apply transformation to previous symbol (last_sym_start - # end) according to rewrite rules: # S* --> S' ::= S S' | # S+ --> S' ::= S S' | S # S? 
--> S' ::= S | sub_rule_id = generate_symbol_id(state, rule_name) out_grammar.append(sub_rule_id) sub_rule_start = len(out_grammar) # placeholder for size of 1st alternate out_grammar.append(TO_BE_FILLED_MARKER) # add preceding symbol to generated rule out_grammar.extend(outbuf[last_sym_start:]) if remaining_src[0] in ("*", "+"): # cause generated rule to recurse out_grammar.append(REF_RULE_MARKER) out_grammar.append(sub_rule_id) # apply actual size out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start # mark end of 1st alternate out_grammar.append(END_OF_ALTERNATE_MARKER) sub_rule_start = len(out_grammar) # placeholder for size of 2nd alternate out_grammar.append(TO_BE_FILLED_MARKER) if remaining_src[0] == "+": # add preceding symbol as alternate only for '+' out_grammar.extend(outbuf[last_sym_start:]) # apply actual size of 2nd alternate out_grammar[sub_rule_start] = len(out_grammar) - sub_rule_start # mark end of 2nd alternate, then end of rule out_grammar.append(END_OF_ALTERNATE_MARKER) out_grammar.append(END_OF_RULE_MARKER) # in original rule, replace previous symbol with reference to generated rule outbuf[last_sym_start:] = [1, sub_rule_id] remaining_src = remove_leading_white_space(remaining_src[1:], is_nested) else: break # apply actual size of this alternate sequence outbuf[out_start_pos] = len(outbuf) - out_start_pos # mark end of alternate outbuf.append(END_OF_ALTERNATE_MARKER) return remaining_src def parse_alternates(state, src, rule_name, rule_id, is_nested): outbuf = [] remaining_src = parse_sequence(state, src, rule_name, outbuf, is_nested) while remaining_src and remaining_src[0] == "|": remaining_src = remove_leading_white_space(remaining_src[1:], True) remaining_src = parse_sequence(state, remaining_src, rule_name, outbuf, is_nested) state.grammar_encoding.append(rule_id) state.grammar_encoding.extend(outbuf) state.grammar_encoding.append(0) return remaining_src def parse_rule(state, src): name, remaining_src = parse_name(src) 
remaining_src = remove_leading_white_space(remaining_src, False) rule_id = get_symbol_id(state, name) if remaining_src[:3] != "::=": raise RuntimeError("expecting ::= at " + remaining_src) remaining_src = remove_leading_white_space(remaining_src[3:], True) remaining_src = parse_alternates(state, remaining_src, name, rule_id, False) if remaining_src and remaining_src[0] == "\r": remaining_src = remaining_src[2:] if remaining_src[1] == "\n" else remaining_src[1:] elif remaining_src and remaining_src[0] == "\n": remaining_src = remaining_src[1:] elif remaining_src: raise RuntimeError("expecting newline or end at " + remaining_src) return remove_leading_white_space(remaining_src, True) def parse_ebnf(src): try: state = ParseState() grammar_repr = remove_leading_white_space(src, True) last_grammar_repr = "" while grammar_repr: if last_grammar_repr: last_parsed_rule_len = len(last_grammar_repr) - len(grammar_repr) logger.debug(f"last_parsed_rule: {last_grammar_repr[:last_parsed_rule_len]}") last_grammar_repr = grammar_repr grammar_repr = parse_rule(state, grammar_repr) state.grammar_encoding.append(0xFFFF) return state except RuntimeError as err: logger.warning("error parsing grammar:", err) return ParseState() def print_rule(file, grammar_encoding, index, symbol_id_names): rule_id = grammar_encoding[index] print(f"<{index}>{symbol_id_names[rule_id]} ::=", end=" ", file=file) pos = index + 1 while grammar_encoding[pos]: if pos - 1 > index: print("|", end=" ", file=file) pos += 1 # sequence size, not needed here while grammar_encoding[pos]: if grammar_encoding[pos] == REF_RULE_MARKER: ref_rule_id = grammar_encoding[pos + 1] print( f"<{pos}>{symbol_id_names[ref_rule_id]}", end=" ", file=file, ) pos += 2 else: print("<{}>[".format(pos), end="", file=file) num_chars = grammar_encoding[pos] pos += 1 for i in range(0, num_chars, 2): print("{}-".format(chr(grammar_encoding[pos + i])), end="", file=file) if i + 1 < num_chars: print("{}".format(chr(grammar_encoding[pos + i + 
1])), end="", file=file) print("]", end=" ", file=file) pos += num_chars pos += 1 print(file=file) return pos + 1 def print_grammar(file, state): pos = 0 symbol_id_names = {v: k for k, v in state.symbol_ids.items()} print("Grammar Rules:", file=file) while state.grammar_encoding[pos] != 0xFFFF: pos = print_rule(file, state.grammar_encoding, pos, symbol_id_names) pos = 0 print("\nBinary representation:", file=file) while state.grammar_encoding[pos] != 0xFFFF: print(f"{state.grammar_encoding[pos]:04x}", end=" ", file=file) pos += 1 print("ffff\n") ################################### # EBNF Grammar Parsing ends here # ################################### class GrammarConstraint(ABC): def __init__(self, grammar_str, start_rule_name, tokenizer): self.tt = 0 self.nt = 0 state = parse_ebnf(grammar_str) grammar_encoding = state.grammar_encoding self.start_rule_id = state.symbol_ids.get(start_rule_name) self.eos_token_id = tokenizer.eos_token_id self.token_trie = TokenTrie(tokenizer) self.tokenizer = tokenizer self.grammar_encoding = grammar_encoding pos = 0 rules: Dict[int, int] = {} while grammar_encoding[pos] != 0xFFFF: rule_id = grammar_encoding[pos] # Store the current position in the 'rules' list at the index corresponding to rule_id. # This effectively maps each rule_id to its position in the grammar encoding. rules[rule_id] = pos pos += 1 # Continue to the next rule in the encoding. # The loop advances by the size indicated at the current position (grammar_encoding[pos]) # plus one for the size field itself. while grammar_encoding[pos]: pos += 1 + grammar_encoding[pos] # Now we're at the end of the rule, # so advance to the next rule by skipping the 0, which means 'end of rule'. 
pos += 1 self.start_rule_pos = rules[self.start_rule_id] self.rules_pos_dict: Dict[int, int] = rules def init_stacks(self): # suppose the start rule position is 0, then grammar_encoding[0] = rule_id # grammar_encoding[1] = rule_size # grammar_encoding[2] = rule_type # this is why we need to add 2 to the start rule position stack = [self.start_rule_pos + 2] # convert to tuple for caching(immutable) return self.advance_stack(tuple(stack)) # For each stack, resolve rules to find the actual characters that are # accepted by this stack (not the set of sub-rules). # This is where the parsing happens. # The parsing is a top-down, left-to-right, depth-first traversal of the # grammar. @lru_cache(maxsize=32768) def advance_stack(self, stack): stack = list(stack) # If the stack is empty, we're done. Because no more tokens should be accepted. if len(stack) == 0: return [stack] # Get the top of the stack. pos = stack[-1] # If the stack head is a terminal(literal), we can resolve it immediately. # literal is marked with 2 in the grammar encoding. if self.grammar_encoding[pos] > 1: return [stack] # The stack head is a nonterminal (a rule reference, 1 in the grammar encoding). # Resolving this rule gives a set of one or more possible positions # (e.g. two in `a ::= b | c`) # We pop the current rule off the stack and, for each option, push: # - the symbol following this symbol in the current rule; then # - the first symbol of the resolved rule. referenced_rule_id = self.grammar_encoding[pos + 1] # subpos should points to the size of the subrule subpos = self.rules_pos_dict[referenced_rule_id] + 1 stacks: List[List[int]] = [] # do depth-first search to find all possible rules and check the next terminal # When this value is non-zero, it indicates that subpos is not yet at the end of the rule, so we can continue. # here subpos is a pointer, and the value in the rule encoding can never be 0 except for the end of the rule. 
while self.grammar_encoding[subpos]: new_stack = stack[:-1] if self.grammar_encoding[pos + 2]: # check if there is a next symbol in the current rule, e.g. `a ::= b c | d` # if yes, push the pos to rule_size to the stack new_stack.append(pos + 2) # if the type of the next symbol is not "empty", push the first symbol of the resolved rule to the stack if self.grammar_encoding[subpos + 1]: new_stack.append(subpos + 1) stacks.extend(self.advance_stack(tuple(new_stack))) # The increment subpos += self.grammar_encoding[subpos] + 1 # moves subpos forward in the grammar encoding array to the next alternative in the current rule. subpos += self.grammar_encoding[subpos] + 1 return stacks def accept_char(self, *args, **kwargs): """Process a byte according to the grammar rules.""" raise NotImplementedError def accept_token_id(self, *args, **kwargs): """Process a token according to the grammar rules.""" raise NotImplementedError def filter_vocab(self, *args, **kwargs): raise NotImplementedError class IncrementalGrammarConstraint(GrammarConstraint): def __init__(self, grammar_str, start_rule_name, tokenizer): super().__init__(grammar_str, start_rule_name, tokenizer) def accept_char(self, char, stacks): byte = char if isinstance(char, int) else ord(char) new_stacks = [] for stack in stacks: # stack is empty if not stack: continue pos = stack[-1] num_chars = self.grammar_encoding[pos] # to make pos point to the size of the char range rule pos += 1 found = False for i in range(0, num_chars, 2): if self.grammar_encoding[pos + i] <= byte and byte <= self.grammar_encoding[pos + i + 1]: found = True break if self.grammar_encoding[pos + i] >= byte and byte >= self.grammar_encoding[pos + i + 1]: found = True break if not found: continue pos += num_chars new_stack = stack[:-1] if self.grammar_encoding[pos]: new_stack.append(pos) new_stacks.extend(self.advance_stack(tuple(new_stack))) return new_stacks def accept_string(self, string: str, stacks: List[List[int]]): for char in string: stacks 
= self.accept_char(char, stacks) return stacks def accept_token_id(self, token_id: int, stacks: List[List[int]]): if token_id == self.eos_token_id: if stacks and all(len(stack) != 0 for stack in stacks): raise Exception( f"At least one of the stack should be empty when EOS is reached. However, " f"the stacks are {stacks}" ) return [] for byte in self.token_trie.id2str(token_id): stacks = self.accept_char(byte, stacks) # check updated stacks # TODO, I commented this out because it will fail when the stack is empty # empty stack means the end of the grammar # assert stacks != [] return stacks def accept_token_ids(self, token_ids: List[int], stacks: List[List[int]], as_string=True): if as_string: string = self.tokenizer.decode(token_ids) stacks = self.accept_string(string, stacks) else: for token_id in token_ids: stacks = self.accept_token_id(token_id, stacks) return stacks def batch_filter_vocab(self, batch_stacks, device): batch_acceptance = [] for stacks in batch_stacks: batch_acceptance.append(self.filter_vocab(stacks, device)) return torch.stack(batch_acceptance) def filter_vocab(self, stacks, device): if not stacks: # Check if stacks is empty # Handle the empty case: for example, return a tensor of False # The size of the tensor should match the size of your vocabulary vocab_size = len(self.token_trie) logger.debug(f"sum of acceptance: {0}") return torch.zeros(vocab_size, dtype=torch.bool, device=device) acceptance_matrix = torch.cat([self.token_acceptance_for_stack(tuple(stack), device) for stack in stacks]) # Merge stacks: any True => True acceptance = acceptance_matrix.reshape(len(stacks), -1).any(dim=0) logger.debug(f"sum of acceptance: {acceptance.sum()}") return acceptance # For each sub-rule in the grammar, cache whether each byte is accepted. 
@lru_cache(maxsize=None) def pos_char_acceptance(self, pos, char): byte = char if isinstance(char, int) else ord(char) num_chars = self.grammar_encoding[pos] pos += 1 for i in range(0, num_chars, 2): start = self.grammar_encoding[pos + i] end = self.grammar_encoding[pos + i + 1] if byte >= start and byte <= end: return True if byte <= start and byte >= end: return True return False # Probably this should be configurable. If the grammar has an exceedingly # large number of states, the correct setting is a tradeoff between GPU # RAM usage and recomputation time. # # The main variable that pushes usage up here is number of states in the # grammar. @lru_cache(maxsize=32768) def token_acceptance_for_stack(self, stack, device): st = time.time() stack = list(stack) # needs to come in as a tuple for lru_cache accepts = [False] * len(self.token_trie) accepts[self.eos_token_id] = len(stack) == 0 if len(stack) == 0: logger.debug("empty stack") def traverse_trie(trie, stacks): for byte, next_trie in trie.items(): if byte == LEAF: token_id = next_trie if token_id != self.eos_token_id: accepts[token_id] = bool(stacks) continue new_stacks = [] for stk in stacks: if not stk: continue pos = stk[-1] num_chars = self.grammar_encoding[pos] if not self.pos_char_acceptance(pos, byte): continue pos += num_chars + 1 new_stack = stk[:-1] if self.grammar_encoding[pos]: new_stack.append(pos) new_stacks.extend(self.advance_stack(tuple(new_stack))) if new_stacks: traverse_trie(next_trie, new_stacks) traverse_trie(self.token_trie.trie, [stack]) et = time.time() - st x = torch.tensor(accepts, dtype=torch.bool, device=device) self.tt += et self.nt += 1 return x class StaticGrammarConstraint(GrammarConstraint): def __init__(self, grammar_str, start_rule_name, tokenizer): super().__init__(grammar_str, start_rule_name, tokenizer) def accept_char(self): raise NotImplementedError ################# # DATA STRUCTURES ################# LEAF = -1 class TokenTrie: def __init__(self, tokenizer): 
self.eos_token_id = tokenizer.eos_token_id self.tokens = [] self.trie = {} self.load_tokens(tokenizer) def id2str(self, token_id): return self.tokens[token_id] def __len__(self): return len(self.tokens) def load_tokens(self, tokenizer): def replace_hex(match): hex_value = match.group(1) return chr(int(hex_value, 16)) if "gpt2" in tokenizer.__class__.__name__.lower(): special = tokenizer.additional_special_tokens_ids # Here, the decoder does a string replace on a bunch of sequences # like ' .' for '.'. This interferes with our assumptions, where a # token should always have exactly one representation. # Fortunately(?) text-generation-inference doesn't seem to run this # cleanup, so we get extraneous spaces. So, in order to generate # the right token set for TGI, we have to skip the space trimming. # See: # https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3588-L3600 def fmt_token(id): if id in special: return None return bytes(tokenizer.decode([id], clean_up_tokenization_spaces=False), "utf-8") elif "llama" in tokenizer.__class__.__name__.lower(): def fmt_token(id): token = tokenizer.convert_ids_to_tokens(id) token = re.sub(r"<0x([0-9a-fA-F]{2})>", replace_hex, token) token = token.replace("▁", " ") return token else: print("Warning: unrecognized tokenizer: using default token formatting") def fmt_token(id): token = tokenizer.convert_ids_to_tokens(id) return token # note: vocab_size doesn't work here because there are also # get_added_vocab() tokens self.tokens = [fmt_token(i) for i in range(len(tokenizer.get_vocab()))] for token_id, token_bytes in enumerate(self.tokens): if token_bytes is not None: self.insert_into_trie(self.trie, token_bytes, token_id) def insert_into_trie(self, trie, token_bytes, token_id): current = trie for byte in token_bytes: if byte not in current: current[byte] = {} current = current[byte] current[LEAF] = token_id @lru_cache(maxsize=5) def initialize_grammar(grammar_string): return 
IncrementalGrammarConstraint(grammar_string.strip(), start_rule_name="root", tokenizer=shared.tokenizer) ================================================ FILE: modules/grammar/logits_process.py ================================================ ''' This file has been 100% copied from this PR to the Transformers library: https://github.com/huggingface/transformers/pull/27557 Author: Saibo-creator Author GitHub: https://github.com/Saibo-creator All credits go to the author. ''' import math import torch from transformers.generation.logits_process import LogitsProcessor from transformers.utils import add_start_docstrings LOGITS_PROCESSOR_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search Return: `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. """ class GrammarConstrainedLogitsProcessor(LogitsProcessor): def __init__(self, grammar_constraint): self.last_size = None self.grammar_constraint = grammar_constraint self.batch_stacks = None def filter_logits(self, logits, device): # resolve each stack to a tensor of True/False for each token # indicating acceptance # acceptance = self.grammar_acceptor.filter_vocab(self.stacks, device) acceptance = self.grammar_constraint.batch_filter_vocab(self.batch_stacks, device) # logger.debug(acceptance) # Logits to -inf where False logits[~acceptance] = -math.inf # TODO: batching def process_logits(self, input_ids, scores, parse_start_index=None): """ :param input_ids: :param scores: :param parse_start_index: default None, which means generate from scratch. 
Set to 0 to parse all input_ids :return: """ # we dynamically create stacks at the first call, so that we know the batch size and beam size if self.batch_stacks is None: self.batch_stacks = [self.grammar_constraint.init_stacks() for _ in range(len(input_ids))] # if self.last_size is not set (which would be the case when processing the first token). # In this case, do nothing. if self.last_size is None: prefix_to_parse = [ single_input_ids[parse_start_index:] if parse_start_index is not None else [] for single_input_ids in input_ids ] # self.grammar_acceptor.accept_token_ids(prefix_to_parse, self.stacks) self.batch_stacks = [ self.grammar_constraint.accept_token_ids(prefix, stack) for prefix, stack in zip(prefix_to_parse, self.batch_stacks) ] # if the length of the current input IDs (input_ids[0]) is exactly one more than self.last_size. # This is expected in a scenario where inputs are processed incrementally, one token at a time. elif len(input_ids[0]) == self.last_size + 1: # self.stacks = self.grammar_acceptor.accept_token_id(input_ids[0][-1], self.stacks) self.batch_stacks = [ self.grammar_constraint.accept_token_id(single_input_ids[-1], stack) for single_input_ids, stack in zip(input_ids, self.batch_stacks) ] # ensure that the input size is consistent with the expected incremental processing # (i.e., one token at a time). else: # here we check if the input_ids are one token longer than the last time we processed # but we don't check if input_ids are actually valid. # Imagine a scenario where we generate 10 tokens, then we replace the 10 generated tokens with 10 new tokens. # In this case, the input_ids will be consistent with the last_size, but the input_ids are not valid. # However, should we really check if the input_ids are valid here? # If we do, then we need to reparse the whole input_ids at each call, which is not efficient. # Maybe we should just trust the user to provide valid input_ids? 
# The conclusion is that, we assume the input_ids are valid, and our generation will be correct. # If the input_ids are not valid, then the generation result will be wrong and we don't take responsibility for that. raise RuntimeError( "Input ID's length is inconsistent with the current state of " "the GrammarConstrainedLogitsProcessor. If you want to process " "another input sequence, please instantiate a new " "GrammarConstrainedLogitsProcessor." ) self.filter_logits(scores, scores.device) self.last_size = len(input_ids[0]) return scores @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: return self.process_logits(input_ids, scores) ================================================ FILE: modules/html_generator.py ================================================ import datetime import functools import html import os import re import time from pathlib import Path import markdown from PIL import Image, ImageOps from modules import shared from modules.reasoning import extract_reasoning from modules.sane_markdown_lists import SaneListExtension from modules.utils import get_available_chat_styles # This is to store the paths to the thumbnails of the profile pictures image_cache = {} def minify_css(css: str) -> str: # Step 1: Remove comments css = re.sub(r'/\*.*?\*/', '', css, flags=re.DOTALL) # Step 2: Remove leading and trailing whitespace css = re.sub(r'^[ \t]*|[ \t]*$', '', css, flags=re.MULTILINE) # Step 3: Remove spaces after specific characters ({ : ; ,}) css = re.sub(r'([:{;,])\s+', r'\1', css) # Step 4: Remove spaces before `{` css = re.sub(r'\s+{', '{', css) # Step 5: Remove empty lines css = re.sub(r'^\s*$', '', css, flags=re.MULTILINE) # Step 6: Collapse all lines into one css = re.sub(r'\n', '', css) return css with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r', encoding='utf-8') as f: readable_css = f.read() with 
open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r', encoding='utf-8') as f: instruct_css = f.read() # Custom chat styles chat_styles = {} for k in get_available_chat_styles(): with open(Path(f'css/chat_style-{k}.css'), 'r', encoding='utf-8') as f: chat_styles[k] = f.read() # Handle styles that derive from other styles for k in chat_styles: lines = chat_styles[k].split('\n') input_string = lines[0] match = re.search(r'chat_style-([a-z\-]*)\.css', input_string) if match: style = match.group(1) chat_styles[k] = chat_styles.get(style, '') + '\n\n' + '\n'.join(lines[1:]) # Reduce the size of the CSS sources above readable_css = minify_css(readable_css) instruct_css = minify_css(instruct_css) for k in chat_styles: chat_styles[k] = minify_css(chat_styles[k]) def fix_newlines(string): string = string.replace('\n', '\n\n') string = re.sub(r"\n{3,}", "\n\n", string) string = string.strip() return string def replace_quotes(text): # Define a list of quote pairs (opening and closing), using HTML entities quote_pairs = [ ('"', '"'), # Double quotes ('“', '”'), # Unicode left and right double quotation marks ('‘', '’'), # Unicode left and right single quotation marks ('«', '»'), # French quotes ('„', '“'), # German quotes ('‘', '’'), # Alternative single quotes ('“', '”'), # Unicode quotes (numeric entities) ('“', '”'), # Unicode quotes (hex entities) ('\u201C', '\u201D'), # Unicode quotes (literal chars) ] # Create a regex pattern that matches any of the quote pairs, including newlines pattern = '|'.join(f'({re.escape(open_q)})(.*?)({re.escape(close_q)})' for open_q, close_q in quote_pairs) # Replace matched patterns with tags, keeping original quotes def replacer(m): # Find the first non-None group set for i in range(1, len(m.groups()), 3): # Step through each sub-pattern's groups if m.group(i): # If this sub-pattern matched return f'{m.group(i)}{m.group(i + 1)}{m.group(i + 2)}' return m.group(0) # Fallback (shouldn't happen) replaced_text = 
re.sub(pattern, replacer, text, flags=re.DOTALL) return replaced_text def replace_blockquote(m): return m.group().replace('\n', '\n> ').replace('\\begin{blockquote}', '').replace('\\end{blockquote}', '') def extract_thinking_block(string): """Extract thinking blocks from the beginning of an HTML-escaped string.""" return extract_reasoning(string, html_escaped=True) def build_tool_call_block(header, body, message_id, index): """Build HTML for a tool call accordion block.""" block_id = f"tool-call-{message_id}-{index}" if body == '...': # Pending placeholder — no expandable body, just title with ellipsis return f'''
{tool_svg_small} {html.escape(header)} ...
''' # Build a plain
 directly to avoid highlight.js auto-detection
    escaped_body = html.escape(body)
    return f'''
    
{tool_svg_small} {html.escape(header)}
{escaped_body}
''' def build_thinking_block(thinking_content, message_id, has_remaining_content, thinking_index=0): """Build HTML for a thinking block.""" if thinking_content is None: return None # Process the thinking content through markdown thinking_html = process_markdown_content(thinking_content) # Generate unique ID for the thinking block block_id = f"thinking-{message_id}-{thinking_index}" # Check if thinking is complete or still in progress is_streaming = not has_remaining_content title_text = "Thinking..." if is_streaming else "Thought" return f'''
{info_svg_small} {title_text}
{thinking_html}
''' def build_main_content_block(content): """Build HTML for the main content block.""" if not content: return "" return process_markdown_content(content) def process_markdown_content(string): """ Process a string through the markdown conversion pipeline. Uses robust manual parsing to ensure correct LaTeX and Code Block rendering. """ if not string: return "" # Define unique placeholders for LaTeX asterisks and underscores LATEX_ASTERISK_PLACEHOLDER = "LATEXASTERISKPLACEHOLDER" LATEX_UNDERSCORE_PLACEHOLDER = "LATEXUNDERSCOREPLACEHOLDER" def protect_asterisks_underscores_in_latex(match): """A replacer function for re.sub to protect asterisks and underscores in multiple LaTeX formats.""" # Check which delimiter group was captured if match.group(1) is not None: # Content from $$...$$ content = match.group(1) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER) return f'{modified_content}' elif match.group(2) is not None: # Content from \[...\] content = match.group(2) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER) return f'\\[{modified_content}\\]' elif match.group(3) is not None: # Content from \(...\) content = match.group(3) modified_content = content.replace('*', LATEX_ASTERISK_PLACEHOLDER) modified_content = modified_content.replace('_', LATEX_UNDERSCORE_PLACEHOLDER) return f'\\({modified_content}\\)' return match.group(0) # Fallback # Make \[ \] LaTeX equations inline pattern = r'^\s*\\\[\s*\n([\s\S]*?)\n\s*\\\]\s*$' replacement = r'\\[ \1 \\]' string = re.sub(pattern, replacement, string, flags=re.MULTILINE) # Escape backslashes string = string.replace('\\', '\\\\') # Quote to string = replace_quotes(string) # Blockquote string = re.sub(r'(^|[\n])>', r'\1>', string) pattern = re.compile(r'\\begin{blockquote}(.*?)\\end{blockquote}', re.DOTALL) string = 
pattern.sub(replace_blockquote, string) # Code block standardization string = string.replace('\\begin{code}', '```') string = string.replace('\\end{code}', '```') string = string.replace('\\begin{align*}', '$$') string = string.replace('\\end{align*}', '$$') string = string.replace('\\begin{align}', '$$') string = string.replace('\\end{align}', '$$') string = string.replace('\\begin{equation}', '$$') string = string.replace('\\end{equation}', '$$') string = string.replace('\\begin{equation*}', '$$') string = string.replace('\\end{equation*}', '$$') string = re.sub(r"(.)```", r"\1\n```", string) # Protect asterisks and underscores within all LaTeX blocks before markdown conversion latex_pattern = re.compile(r'((?:^|[\r\n\s])\$\$[^`]*?\$\$)|\\\[(.*?)\\\]|\\\((.*?)\\\)', re.DOTALL) string = latex_pattern.sub(protect_asterisks_underscores_in_latex, string) result = '' is_code = False is_latex = False # Manual line iteration for robust structure parsing for line in string.split('\n'): stripped_line = line.strip() if stripped_line.startswith('```'): is_code = not is_code elif stripped_line.startswith('$$') and (stripped_line == "$$" or not stripped_line.endswith('$$')): is_latex = not is_latex elif stripped_line.endswith('$$'): is_latex = False elif stripped_line.startswith('\\\\[') and not stripped_line.endswith('\\\\]'): is_latex = True elif stripped_line.startswith('\\\\]'): is_latex = False elif stripped_line.endswith('\\\\]'): is_latex = False result += line # Don't add an extra \n for code, LaTeX, or tables if is_code or is_latex or line.startswith('|'): result += '\n' # Also don't add an extra \n for lists elif stripped_line.startswith('-') or stripped_line.startswith('*') or stripped_line.startswith('+') or stripped_line.startswith('>') or re.match(r'\d+\.', stripped_line): result += ' \n' else: result += ' \n' result = result.strip() if is_code: result += '\n```' # Unfinished code block # Unfinished list, like "\n1.". 
A |delete| string is added and then # removed to force a
    or
      to be generated instead of a

      . list_item_pattern = r'(\n\d+\.?|\n\s*[-*+]\s*([*_~]{1,3})?)$' if re.search(list_item_pattern, result): delete_str = '|delete|' if re.search(r'(\d+\.?)$', result) and not result.endswith('.'): result += '.' # Add the delete string after the list item result = re.sub(list_item_pattern, r'\g<1> ' + delete_str, result) # Convert to HTML using markdown html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()]) # Remove the delete string from the HTML output pos = html_output.rfind(delete_str) if pos > -1: html_output = html_output[:pos] + html_output[pos + len(delete_str):] else: # Convert to HTML using markdown html_output = markdown.markdown(result, extensions=['fenced_code', 'tables', SaneListExtension()]) # Restore the LaTeX asterisks and underscores after markdown conversion html_output = html_output.replace(LATEX_ASTERISK_PLACEHOLDER, '*') html_output = html_output.replace(LATEX_UNDERSCORE_PLACEHOLDER, '_') # Remove extra newlines before html_output = re.sub(r'\s*', '', html_output) # Unescape code blocks pattern = re.compile(r']*>(.*?)', re.DOTALL) html_output = pattern.sub(lambda x: html.unescape(x.group()), html_output) # Unescape backslashes html_output = html_output.replace('\\\\', '\\') # Wrap tables in a scrollable div html_output = html_output.replace('', '
      ').replace('
      ', '') return html_output @functools.lru_cache(maxsize=None) def convert_to_markdown(string, message_id=None): """ Convert a string to markdown HTML with support for multiple block types. Blocks are assembled in order: thinking, main content, etc. """ if not string: return "" # Use a default message ID if none provided if message_id is None: message_id = "unknown" # Find tool call blocks by position, then process the text segments # between them using extract_thinking_block (which supports all # THINKING_FORMATS, including end-only variants like Qwen's). tool_call_pattern = re.compile(r'(.*?)\n(.*?)\n', re.DOTALL) tool_calls = list(tool_call_pattern.finditer(string)) if not tool_calls: # No tool calls — use original single-pass extraction thinking_content, remaining_content = extract_thinking_block(string) blocks = [] thinking_html = build_thinking_block(thinking_content, message_id, bool(remaining_content)) if thinking_html: blocks.append(thinking_html) main_html = build_main_content_block(remaining_content) if main_html: blocks.append(main_html) return ''.join(blocks) # Split string into text segments around tool_call blocks and # run extract_thinking_block on each segment for full format support. 
html_parts = [] last_end = 0 tool_idx = 0 think_idx = 0 def process_text_segment(text, is_last_segment): """Process a text segment between tool_call blocks for thinking content.""" nonlocal think_idx if not text.strip(): return while text.strip(): thinking_content, remaining = extract_thinking_block(text) if thinking_content is None: break has_remaining = bool(remaining.strip()) or not is_last_segment html_parts.append(build_thinking_block(thinking_content, message_id, has_remaining, think_idx)) think_idx += 1 text = remaining if text.strip(): html_parts.append(process_markdown_content(text)) for tc in tool_calls: # Process text before this tool_call process_text_segment(string[last_end:tc.start()], is_last_segment=False) # Add tool call accordion header = tc.group(1).strip() body = tc.group(2).strip() html_parts.append(build_tool_call_block(header, body, message_id, tool_idx)) tool_idx += 1 last_end = tc.end() # Process text after the last tool_call process_text_segment(string[last_end:], is_last_segment=True) return ''.join(html_parts) def convert_to_markdown_wrapped(string, message_id=None, use_cache=True): ''' Used to avoid caching convert_to_markdown calls during streaming. ''' if use_cache: return convert_to_markdown(string, message_id=message_id) return convert_to_markdown.__wrapped__(string, message_id=message_id) def generate_basic_html(string): convert_to_markdown.cache_clear() string = convert_to_markdown(string) string = f'

      {string}
      ' return string def make_thumbnail(image): image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) if image.size[1] > 470: image = ImageOps.fit(image, (350, 470), Image.LANCZOS) return image def get_image_cache(path): cache_folder = Path(shared.args.disk_cache_dir) if not cache_folder.exists(): cache_folder.mkdir() mtime = os.stat(path).st_mtime if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): img = make_thumbnail(Image.open(path)) old_p = Path(f'{cache_folder}/{path.name}_cache.png') p = Path(f'{cache_folder}/cache_{path.name}.png') if old_p.exists(): old_p.rename(p) output_file = p img.convert('RGBA').save(output_file, format='PNG') image_cache[path] = [mtime, output_file.as_posix()] return image_cache[path][1] copy_svg = '''''' refresh_svg = '''''' continue_svg = '''''' remove_svg = '''''' branch_svg = '''''' edit_svg = '''''' info_svg = '''''' info_svg_small = '''''' tool_svg_small = '''''' attachment_svg = '''''' copy_button = f'' branch_button = f'' edit_button = f'' refresh_button = f'' continue_button = f'' remove_button = f'' info_button = f'' def format_message_timestamp(history, role, index, tooltip_include_timestamp=True): """Get a formatted timestamp HTML span for a message if available""" key = f"{role}_{index}" if 'metadata' in history and key in history['metadata'] and history['metadata'][key].get('timestamp'): timestamp = history['metadata'][key]['timestamp'] tooltip_text = get_message_tooltip(history, role, index, include_timestamp=tooltip_include_timestamp) title_attr = f' title="{html.escape(tooltip_text)}"' if tooltip_text else '' return f"{timestamp}" return "" def format_message_attachments(history, role, index): """Get formatted HTML for message attachments if available""" key = f"{role}_{index}" if 'metadata' in history and key in history['metadata'] and 'attachments' in history['metadata'][key]: attachments = history['metadata'][key]['attachments'] if 
not attachments: return "" attachments_html = '
      ' for attachment in attachments: name = html.escape(attachment["name"]) if attachment.get("type") == "image": image_data = attachment.get("image_data", "") attachments_html += ( f'
      ' f'{name}' f'
      {name}
      ' f'
      ' ) else: # Make clickable if URL exists (web search) if "url" in attachment: name = f'{name}' attachments_html += ( f'
      ' f'
      {attachment_svg}
      ' f'
      {name}
      ' f'
      ' ) attachments_html += '
      ' return attachments_html return "" def get_message_tooltip(history, role, index, include_timestamp=True): """Get tooltip text combining timestamp and model name for a message""" key = f"{role}_{index}" if 'metadata' not in history or key not in history['metadata']: return "" meta = history['metadata'][key] tooltip_parts = [] if include_timestamp and meta.get('timestamp'): tooltip_parts.append(meta['timestamp']) if meta.get('model_name'): tooltip_parts.append(f"Model: {meta['model_name']}") return " | ".join(tooltip_parts) def get_version_navigation_html(history, i, role): """Generate simple navigation arrows for message versions""" key = f"{role}_{i}" metadata = history.get('metadata', {}) if key not in metadata or 'versions' not in metadata[key]: return "" versions = metadata[key]['versions'] # Default to the last version if current_version_index isn't set in metadata current_idx = metadata[key].get('current_version_index', len(versions) - 1 if versions else 0) if len(versions) <= 1: return "" left_disabled = ' disabled' if current_idx == 0 else '' right_disabled = ' disabled' if current_idx >= len(versions) - 1 else '' left_arrow = f'' right_arrow = f'' position = f'{current_idx + 1}/{len(versions)}' return f'
      {left_arrow}{position}{right_arrow}
      ' def actions_html(history, i, role, info_message=""): action_buttons = "" version_nav_html = "" if role == "assistant": action_buttons = ( f'{copy_button}' f'{edit_button}' f'{refresh_button if i == len(history["visible"]) - 1 else ""}' f'{continue_button if i == len(history["visible"]) - 1 else ""}' f'{remove_button if i == len(history["visible"]) - 1 else ""}' f'{branch_button}' ) version_nav_html = get_version_navigation_html(history, i, "assistant") elif role == "user": action_buttons = ( f'{copy_button}' f'{edit_button}' ) version_nav_html = get_version_navigation_html(history, i, "user") return (f'
      ' f'{action_buttons}' f'{info_message}' f'
      ' f'{version_nav_html}') def generate_instruct_html(history, last_message_only=False): if not last_message_only: output = f'
      ' else: output = "" def create_message(role, content, raw_content): """Inner function that captures variables from outer scope.""" class_name = "user-message" if role == "user" else "assistant-message" # Get role-specific data timestamp = format_message_timestamp(history, role, i) attachments = format_message_attachments(history, role, i) # Create info button if timestamp exists info_message = "" if timestamp: tooltip_text = get_message_tooltip(history, role, i) info_message = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"') return ( f'
      ' f'
      ' f'
      {content}
      ' f'{attachments}' f'{actions_html(history, i, role, info_message)}' f'
      ' f'
      ' ) # Determine range start_idx = len(history['visible']) - 1 if last_message_only else 0 end_idx = len(history['visible']) for i in range(start_idx, end_idx): row_visible = history['visible'][i] row_internal = history['internal'][i] # Convert content if last_message_only: converted_visible = [None, convert_to_markdown_wrapped(row_visible[1], message_id=i, use_cache=i != len(history['visible']) - 1)] else: converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible] # Generate messages if not last_message_only and converted_visible[0]: output += create_message("user", converted_visible[0], row_internal[0]) output += create_message("assistant", converted_visible[1], row_internal[1]) if not last_message_only: output += "
      " return output def get_character_image_with_cache_buster(): """Get character image URL with cache busting based on file modification time""" cache_path = shared.user_data_dir / "cache" / "pfp_character_thumb.png" if cache_path.exists(): mtime = int(cache_path.stat().st_mtime) return f'' return '' def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=False, last_message_only=False): if not last_message_only: output = f'
      ' else: output = "" img_bot = get_character_image_with_cache_buster() def create_message(role, content, raw_content): """Inner function for CAI-style messages.""" circle_class = "circle-you" if role == "user" else "circle-bot" name = name1 if role == "user" else name2 # Get role-specific data timestamp = format_message_timestamp(history, role, i, tooltip_include_timestamp=False) attachments = format_message_attachments(history, role, i) # Get appropriate image if role == "user": img = (f'' if (shared.user_data_dir / "cache" / "pfp_me.png").exists() else '') else: img = img_bot return ( f'
      ' f'
      {img}
      ' f'
      ' f'
      {name}{timestamp}
      ' f'
      {content}
      ' f'{attachments}' f'{actions_html(history, i, role)}' f'
      ' f'
      ' ) # Determine range start_idx = len(history['visible']) - 1 if last_message_only else 0 end_idx = len(history['visible']) for i in range(start_idx, end_idx): row_visible = history['visible'][i] row_internal = history['internal'][i] # Convert content if last_message_only: converted_visible = [None, convert_to_markdown_wrapped(row_visible[1], message_id=i, use_cache=i != len(history['visible']) - 1)] else: converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible] # Generate messages if not last_message_only and converted_visible[0]: output += create_message("user", converted_visible[0], row_internal[0]) output += create_message("assistant", converted_visible[1], row_internal[1]) if not last_message_only: output += "
      " return output def time_greeting(): current_hour = datetime.datetime.now().hour if 5 <= current_hour < 12: return "Good morning!" elif 12 <= current_hour < 18: return "Good afternoon!" else: return "Good evening!" def chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=False, last_message_only=False): if len(history['visible']) == 0: greeting = f"
      {time_greeting()} How can I help you today?
      " result = f'
      {greeting}
      ' elif mode == 'instruct': result = generate_instruct_html(history, last_message_only=last_message_only) else: result = generate_cai_chat_html(history, name1, name2, style, character, reset_cache=reset_cache, last_message_only=last_message_only) return {'html': result, 'last_message_only': last_message_only} ================================================ FILE: modules/image_models.py ================================================ import time import modules.shared as shared from modules.logging_colors import logger from modules.utils import resolve_model_path def get_quantization_config(quant_method): """ Get the appropriate quantization config based on the selected method. Applies quantization to both the transformer and the text_encoder. """ import torch # Import BitsAndBytesConfig from BOTH libraries to be safe from diffusers import BitsAndBytesConfig as DiffusersBnBConfig from diffusers import TorchAoConfig from diffusers.quantizers import PipelineQuantizationConfig from transformers import BitsAndBytesConfig as TransformersBnBConfig if quant_method == 'none' or not quant_method: return None # Bitsandbytes 8-bit quantization elif quant_method == 'bnb-8bit': return PipelineQuantizationConfig( quant_mapping={ "transformer": DiffusersBnBConfig( load_in_8bit=True ), "text_encoder": TransformersBnBConfig( load_in_8bit=True ) } ) # Bitsandbytes 4-bit quantization elif quant_method == 'bnb-4bit': return PipelineQuantizationConfig( quant_mapping={ "transformer": DiffusersBnBConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True ), "text_encoder": TransformersBnBConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True ) } ) # torchao int8 weight-only elif quant_method == 'torchao-int8wo': return PipelineQuantizationConfig( quant_mapping={ "transformer": TorchAoConfig("int8wo"), "text_encoder": TorchAoConfig("int8wo") } ) # 
torchao fp4 (e2m1) elif quant_method == 'torchao-fp4': return PipelineQuantizationConfig( quant_mapping={ "transformer": TorchAoConfig("fp4_e2m1"), "text_encoder": TorchAoConfig("fp4_e2m1") } ) # torchao float8 weight-only elif quant_method == 'torchao-float8wo': return PipelineQuantizationConfig( quant_mapping={ "transformer": TorchAoConfig("float8wo"), "text_encoder": TorchAoConfig("float8wo") } ) else: logger.warning(f"Unknown quantization method: {quant_method}. Loading without quantization.") return None def get_pipeline_type(pipe): """ Detect the pipeline type based on the loaded pipeline class. Returns: str: 'zimage', 'qwenimage', or 'unknown' """ class_name = pipe.__class__.__name__ if class_name == 'ZImagePipeline': return 'zimage' elif class_name == 'QwenImagePipeline': return 'qwenimage' else: return 'unknown' def load_image_model(model_name, dtype='bfloat16', attn_backend='sdpa', cpu_offload=False, compile_model=False, quant_method='none'): """ Load a diffusers image generation model. 
Args: model_name: Name of the model directory dtype: 'bfloat16' or 'float16' attn_backend: 'sdpa' or 'flash_attention_2' cpu_offload: Enable CPU offloading for low VRAM compile_model: Compile the model for faster inference (slow first run) quant_method: 'none', 'bnb-8bit', 'bnb-4bit', or torchao options (int8wo, fp4, float8wo) """ import torch from diffusers import DiffusionPipeline from modules.torch_utils import get_device logger.info(f"Loading image model \"{model_name}\" with quantization: {quant_method}") t0 = time.time() dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16} target_dtype = dtype_map.get(dtype, torch.bfloat16) model_path = resolve_model_path(model_name, image_model=True) try: # Get quantization config based on selected method pipeline_quant_config = get_quantization_config(quant_method) # Load the pipeline load_kwargs = { "torch_dtype": target_dtype, "low_cpu_mem_usage": True, } if pipeline_quant_config is not None: load_kwargs["quantization_config"] = pipeline_quant_config # Use DiffusionPipeline for automatic pipeline detection # This handles both ZImagePipeline and QwenImagePipeline pipe = DiffusionPipeline.from_pretrained( str(model_path), **load_kwargs ) pipeline_type = get_pipeline_type(pipe) if not cpu_offload: pipe.to(get_device()) modules = ["transformer", "unet"] # Set attention backend if attn_backend == 'flash_attention_2': for name in modules: mod = getattr(pipe, name, None) if hasattr(mod, "set_attention_backend"): mod.set_attention_backend("flash") break # Compile model if compile_model: for name in modules: mod = getattr(pipe, name, None) if hasattr(mod, "compile"): logger.info("Compiling model (first run will be slow)...") mod.compile() break if cpu_offload: pipe.enable_model_cpu_offload() shared.image_model = pipe shared.image_model_name = model_name shared.image_pipeline_type = pipeline_type logger.info(f"Loaded image model \"{model_name}\" in {(time.time() - t0):.2f} seconds.") return pipe except Exception as e: 
logger.error(f"Failed to load image model: {str(e)}") return None def unload_image_model(): """Unload the current image model and free VRAM.""" if shared.image_model is None: return del shared.image_model shared.image_model = None shared.image_model_name = 'None' shared.image_pipeline_type = None from modules.torch_utils import clear_torch_cache clear_torch_cache() logger.info("Image model unloaded.") ================================================ FILE: modules/image_utils.py ================================================ import base64 import io import os from pathlib import Path from typing import Any, List, Tuple from PIL import Image from modules.logging_colors import logger def open_image_safely(path): if path is None or not isinstance(path, str) or not Path(path).exists(): return None if os.path.islink(path): return None try: return Image.open(path) except Exception as e: logger.error(f"Failed to open image file: {path}. Reason: {e}") return None def convert_pil_to_base64(image: Image.Image) -> str: """Converts a PIL Image to a base64 encoded string.""" buffered = io.BytesIO() # Save image to an in-memory bytes buffer in PNG format image.save(buffered, format="PNG") # Encode the bytes to a base64 string return base64.b64encode(buffered.getvalue()).decode('utf-8') def decode_base64_image(base64_string: str) -> Image.Image: """Decodes a base64 string to a PIL Image.""" try: if base64_string.startswith('data:image/'): base64_string = base64_string.split(',', 1)[1] image_data = base64.b64decode(base64_string) image = Image.open(io.BytesIO(image_data)) return image except Exception as e: logger.error(f"Failed to decode base64 image: {e}") raise ValueError(f"Invalid base64 image data: {e}") def process_message_content(content: Any) -> Tuple[str, List[Image.Image]]: """ Processes message content that may contain text and images. Returns: A tuple of (text_content, list_of_pil_images). 
""" if isinstance(content, str): return content, [] if isinstance(content, list): text_parts = [] images = [] for item in content: if not isinstance(item, dict): continue item_type = item.get('type', '') if item_type == 'text': text_parts.append(item.get('text', '')) elif item_type == 'image_url': image_url_data = item.get('image_url', {}) image_url = image_url_data.get('url', '') if image_url.startswith('data:image/'): try: images.append(decode_base64_image(image_url)) except Exception as e: logger.warning(f"Failed to process a base64 image: {e}") elif image_url.startswith('http'): # Support external URLs try: import requests from urllib.parse import urljoin from modules.web_search import _validate_url _validate_url(image_url) url = image_url for _ in range(5): response = requests.get(url, timeout=10, allow_redirects=False) if response.is_redirect and 'Location' in response.headers: url = urljoin(url, response.headers['Location']) _validate_url(url) else: break response.raise_for_status() image_data = response.content image = Image.open(io.BytesIO(image_data)) images.append(image) logger.info("Successfully loaded external image from URL") except Exception as e: logger.warning(f"Failed to fetch external image: {e}") else: logger.warning(f"Unsupported image URL format: {image_url[:70]}...") return ' '.join(text_parts), images return str(content), [] def convert_image_attachments_to_pil(image_attachments: List[dict]) -> List[Image.Image]: """Convert webui image_attachments format to PIL Images.""" pil_images = [] for attachment in image_attachments: if attachment.get('type') == 'image' and 'image_data' in attachment: try: image = decode_base64_image(attachment['image_data']) if image.mode != 'RGB': image = image.convert('RGB') pil_images.append(image) except Exception as e: logger.warning(f"Failed to process image attachment: {e}") return pil_images def convert_openai_messages_to_images(messages: List[dict]) -> List[Image.Image]: """Convert OpenAI messages format to 
PIL Images.""" all_images = [] for message in messages: if isinstance(message, dict) and 'content' in message: _, images = process_message_content(message['content']) all_images.extend(images) return all_images ================================================ FILE: modules/llama_cpp_server.py ================================================ import json import os import pprint import re import socket import subprocess import sys import threading import time from pathlib import Path from typing import Any, List import llama_cpp_binaries import requests from modules import shared from modules.image_utils import ( convert_image_attachments_to_pil, convert_openai_messages_to_images, convert_pil_to_base64 ) from modules.logging_colors import logger from modules.utils import resolve_model_path llamacpp_valid_cache_types = {"fp16", "q8_0", "q4_0"} class LlamaServer: def __init__(self, model_path, server_path=None): """ Initialize and start a server for llama.cpp models. """ self.model_path = model_path self.server_path = server_path self.port = self._find_available_port() self.process = None self.session = requests.Session() self.vocabulary_size = None self.n_ctx = None self.bos_token = "" self.last_prompt_token_count = 0 # Start the server self._start_server() def encode(self, text, add_bos_token=False, **kwargs): if self.bos_token and text.startswith(self.bos_token): add_bos_token = False url = f"http://127.0.0.1:{self.port}/tokenize" payload = { "content": text, "add_special": add_bos_token, } response = self.session.post(url, json=payload) result = response.json() return result.get("tokens", []) def decode(self, token_ids, **kwargs): url = f"http://127.0.0.1:{self.port}/detokenize" payload = { "tokens": token_ids, } response = self.session.post(url, json=payload) result = response.json() return result.get("content", "") def prepare_payload(self, state): payload = { "temperature": state["temperature"] if not state["dynamic_temperature"] else (state["dynatemp_low"] + 
state["dynatemp_high"]) / 2, "dynatemp_range": 0 if not state["dynamic_temperature"] else (state["dynatemp_high"] - state["dynatemp_low"]) / 2, "dynatemp_exponent": state["dynatemp_exponent"], "top_k": state["top_k"], "top_p": state["top_p"], "min_p": state["min_p"], "top_n_sigma": state["top_n_sigma"] if state["top_n_sigma"] > 0 else -1, "adaptive_target": state["adaptive_target"] if state["adaptive_target"] > 0 else -1, "adaptive_decay": state["adaptive_decay"], "typical_p": state["typical_p"], "repeat_penalty": state["repetition_penalty"], "repeat_last_n": state["repetition_penalty_range"], "presence_penalty": state["presence_penalty"], "frequency_penalty": state["frequency_penalty"], "dry_multiplier": state["dry_multiplier"], "dry_base": state["dry_base"], "dry_allowed_length": state["dry_allowed_length"], "dry_penalty_last_n": state["repetition_penalty_range"], "xtc_probability": state["xtc_probability"], "xtc_threshold": state["xtc_threshold"], "mirostat": state["mirostat_mode"], "mirostat_tau": state["mirostat_tau"], "mirostat_eta": state["mirostat_eta"], "grammar": state["grammar_string"], "seed": state["seed"], "ignore_eos": state["ban_eos_token"], } # DRY dry_sequence_breakers = state['dry_sequence_breakers'] if not dry_sequence_breakers.startswith("["): dry_sequence_breakers = "[" + dry_sequence_breakers + "]" dry_sequence_breakers = json.loads(dry_sequence_breakers) payload["dry_sequence_breakers"] = dry_sequence_breakers # Sampler order if state["sampler_priority"]: samplers = state["sampler_priority"] samplers = samplers.split("\n") if isinstance(samplers, str) else samplers filtered_samplers = [] penalty_found = False for s in samplers: if s.strip() in ["dry", "top_k", "top_p", "top_n_sigma", "min_p", "temperature", "xtc"]: filtered_samplers.append(s.strip()) elif s.strip() == "typical_p": filtered_samplers.append("typ_p") elif not penalty_found and s.strip() == "repetition_penalty": filtered_samplers.append("penalties") penalty_found = True # Move 
temperature to the end if temperature_last is true and temperature exists in the list
            if state["temperature_last"] and "temperature" in filtered_samplers:
                filtered_samplers.remove("temperature")
                filtered_samplers.append("temperature")

            # adaptive-p replaces the default dist sampler; llama.cpp always
            # places it at the end of the chain regardless of position, so we
            # activate it based on the parameter value rather than sampler order.
            if state.get("adaptive_target", 0) > 0:
                filtered_samplers.append("adaptive_p")

            payload["samplers"] = filtered_samplers

        # Banned token ids (comma-separated string) become [id, False] entries;
        # explicit biases from state['logit_bias'] (id string -> bias) are appended after them.
        logit_bias = []
        if state['custom_token_bans']:
            logit_bias.extend([[int(token_id.strip()), False] for token_id in state['custom_token_bans'].split(',') if token_id.strip()])

        if state.get('logit_bias'):
            for token_id_str, bias in state['logit_bias'].items():
                logit_bias.append([int(token_id_str), bias])

        if logit_bias:
            payload["logit_bias"] = logit_bias

        # Only ask the server for per-token probabilities when a positive count was requested.
        n_probs = state.get('logprobs', 0)
        if n_probs and n_probs > 0:
            payload["n_probs"] = n_probs

        return payload

    def _process_images_for_generation(self, state: dict) -> List[Any]:
        """
        Process all possible image inputs and return PIL images.

        Exactly one source is used, checked in this order: web UI
        attachments, Chat Completions message history, legacy raw images.

        Args:
            state: Generation state dict.

        Returns:
            A list of PIL images; empty when the request carries none.
        """
        pil_images = []
        # Source 1: Web UI (from chatbot_wrapper)
        if 'image_attachments' in state and state['image_attachments']:
            pil_images.extend(convert_image_attachments_to_pil(state['image_attachments']))
        # Source 2: Chat Completions API (/v1/chat/completions)
        elif 'history' in state and state.get('history', {}).get('messages'):
            pil_images.extend(convert_openai_messages_to_images(state['history']['messages']))
        # Source 3: Legacy Completions API (/v1/completions)
        elif 'raw_images' in state and state['raw_images']:
            pil_images.extend(state.get('raw_images', []))

        return pil_images

    def is_multimodal(self) -> bool:
        """Check if this model supports multimodal input (an mmproj file is configured)."""
        return shared.args.mmproj not in [None, 'None']

    def generate_with_streaming(self, prompt, state):
        """
        Stream a completion from the llama.cpp server's /completion endpoint.

        Side effects: sets self.last_prompt_token_count (text tokens, plus a
        fixed per-image estimate in the multimodal case) and collects any
        'completion_probabilities' chunks into self.last_completion_probabilities.

        Args:
            prompt: The fully formatted prompt string.
            state: Generation state dict (sampler settings, token limits,
                add_bos_token, optional images and stop_event).

        Yields:
            The cumulative generated text after each chunk with content.
        """
        url = f"http://127.0.0.1:{self.port}/completion"
        payload = self.prepare_payload(state)

        pil_images = []
        if shared.is_multimodal:
            pil_images = self._process_images_for_generation(state)

        if pil_images:
            # Multimodal case
            IMAGE_TOKEN_COST_ESTIMATE = 600  # A safe, conservative estimate per image

            base64_images = [convert_pil_to_base64(img) for img in pil_images]
            payload["prompt"] = {
                "prompt_string": prompt,
                "multimodal_data": base64_images
            }

            # Calculate an estimated token count
            text_tokens = self.encode(prompt, add_bos_token=state["add_bos_token"])
            self.last_prompt_token_count = len(text_tokens) + (len(pil_images) * IMAGE_TOKEN_COST_ESTIMATE)
        else:
            # Text only case
            token_ids = self.encode(prompt, add_bos_token=state["add_bos_token"])
            self.last_prompt_token_count = len(token_ids)
            payload["prompt"] = token_ids

        # auto_max_new_tokens: use whatever room is left in the truncation window.
        if state['auto_max_new_tokens']:
            max_new_tokens = state['truncation_length'] - self.last_prompt_token_count
        else:
            max_new_tokens = state['max_new_tokens']

        payload.update({
            "n_predict": max_new_tokens,
            "stream": True,
            "cache_prompt": True
        })

        if shared.args.verbose:
            logger.info("GENERATE_PARAMS=")
            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
            pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
            print()

        # Make the generation request
        response = self.session.post(url, json=payload, stream=True)
        try:
            if response.status_code == 400 and response.json().get("error", {}).get("type") == "exceed_context_size_error":
                logger.error("The request exceeds the available context size, try increasing it")
                return
            else:
                response.raise_for_status()  # Raise an exception for HTTP errors

            full_text = ""
            self.last_completion_probabilities = []

            # Process the streaming response
            stop_event = state.get('stop_event')
            for line in response.iter_lines():
                # Honour both the global stop flag and a per-request stop event.
                if shared.stop_everything or (stop_event and stop_event.is_set()):
                    break

                if not line:
                    continue

                try:
                    line = line.decode('utf-8')

                    # Check if the line starts with "data: " and remove it
                    if line.startswith('data: '):
                        line = line[6:]  # Remove the "data: " prefix

                    # Parse the JSON data
                    data = json.loads(line)

                    # Extract the token content
                    if data.get('content', ''):
                        full_text += data['content']
                        yield full_text

                    # Capture logprobs if present
                    if 'completion_probabilities' in data:
                        self.last_completion_probabilities.extend(data['completion_probabilities'])

                    # Check if generation is complete
                    if data.get('stop', False):
                        break

                except json.JSONDecodeError as e:
                    # Log the error and the problematic line
                    print(f"JSON decode error: {e}")
                    print(f"Problematic line: {line}")
                    continue
        finally:
            # Always release the streaming connection, including on early exit.
            response.close()

    def generate(self, prompt, state):
        """Run generate_with_streaming to completion and return the final text ("" if nothing was yielded)."""
        output = ""
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output

    def get_logits(self, prompt, state, n_probs=128, use_samplers=False):
        """
        Get the logits/probabilities for the next token after a prompt.

        Sends a non-streaming /completion request with n_predict=0 and retries
        up to 5 times, 0.05 s apart, until the response contains
        'completion_probabilities'.

        Returns:
            The first entry's "top_probs" when use_samplers is true
            (post_sampling_probs), otherwise its "top_logprobs".

        Raises:
            Exception: If none of the attempts returned 'completion_probabilities'.
        """
        url = f"http://127.0.0.1:{self.port}/completion"
        payload = self.prepare_payload(state)
        payload.update({
            "prompt": self.encode(prompt, add_bos_token=state["add_bos_token"]),
            "n_predict": 0,
            "logprobs": True,
            "n_probs": n_probs,
            "stream": False,
            "post_sampling_probs": use_samplers,
        })

        if shared.args.verbose and use_samplers:
            logger.info("GENERATE_PARAMS=")
            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
            pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
            print()

        for retry in range(5):
            response = self.session.post(url, json=payload)
            result = response.json()

            if "completion_probabilities" in result:
                if use_samplers:
                    return result["completion_probabilities"][0]["top_probs"]
                else:
                    return result["completion_probabilities"][0]["top_logprobs"]

            time.sleep(0.05)
        else:
            # for/else: reached only when all retries failed to return probabilities.
            raise Exception(f"Unexpected response format: 'completion_probabilities' not found in {result}")

    def _get_vocabulary_size(self):
        """Get and store the model's vocabulary size (meta.n_vocab of the first /v1/models entry)."""
        url = f"http://127.0.0.1:{self.port}/v1/models"
        response = self.session.get(url).json()

        if "data" in response and len(response["data"]) > 0:
            model_info = response["data"][0]
            if "meta" in model_info and "n_vocab" in
model_info["meta"]: self.vocabulary_size = model_info["meta"]["n_vocab"] def _get_bos_token(self): """Get and store the model's BOS token and context size.""" url = f"http://127.0.0.1:{self.port}/props" response = self.session.get(url).json() if "bos_token" in response: self.bos_token = response["bos_token"] # Get actual n_ctx from the server (important when --fit auto-selects it) n_ctx = response.get("default_generation_settings", {}).get("n_ctx") if n_ctx: self.n_ctx = n_ctx def _is_port_available(self, port): """Check if a port is available for use.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind(('', port)) return True except OSError: return False def _find_available_port(self): """Find an available port, preferring main port + 5.""" preferred_port = shared.args.api_port + 5 if self._is_port_available(preferred_port): return preferred_port # Fall back to OS-assigned random port with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(('', 0)) return s.getsockname()[1] def _start_server(self): """Start the llama.cpp server and wait until it's ready.""" # Determine the server path if self.server_path is None: self.server_path = llama_cpp_binaries.get_binary_path() # Build the command cmd = [ self.server_path, "--model", self.model_path, "--batch-size", str(shared.args.batch_size), "--ubatch-size", str(shared.args.ubatch_size), "--port", str(self.port), "--no-webui", "--flash-attn", "on", ] if shared.args.ctx_size > 0: cmd += ["--ctx-size", str(shared.args.ctx_size)] elif shared.args.gpu_layers >= 0: cmd += ["--ctx-size", "8192"] if shared.args.gpu_layers >= 0: cmd += ["--gpu-layers", str(shared.args.gpu_layers), "--fit", "off"] else: cmd += ["--fit", "on"] cmd += ["--fit-ctx", "8192"] if shared.args.fit_target: cmd += ["--fit-target", shared.args.fit_target] if shared.args.threads > 0: cmd += ["--threads", str(shared.args.threads)] if shared.args.threads_batch > 0: cmd += ["--threads-batch", str(shared.args.threads_batch)] 
if shared.args.cpu_moe: cmd.append("--cpu-moe") if shared.args.no_mmap: cmd.append("--no-mmap") if shared.args.mlock: cmd.append("--mlock") if shared.args.tensor_split: cmd += ["--tensor-split", shared.args.tensor_split] if shared.args.numa: cmd += ["--numa", "distribute"] if shared.args.no_kv_offload: cmd.append("--no-kv-offload") if shared.args.row_split: cmd += ["--split-mode", "row"] cache_type = "fp16" if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types: cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type] cache_type = shared.args.cache_type if shared.args.mmproj not in [None, 'None']: path = Path(shared.args.mmproj) if not path.exists(): path = shared.user_data_dir / 'mmproj' / shared.args.mmproj if path.exists(): cmd += ["--mmproj", str(path)] if shared.args.model_draft not in [None, 'None']: path = resolve_model_path(shared.args.model_draft) if path.is_file(): model_file = path else: model_file = sorted(path.glob('*.gguf'))[0] cmd += ["--model-draft", str(model_file)] if shared.args.draft_max > 0: cmd += ["--draft-max", str(shared.args.draft_max)] if shared.args.gpu_layers_draft > 0: cmd += ["--gpu-layers-draft", str(shared.args.gpu_layers_draft)] if shared.args.device_draft: cmd += ["--device-draft", shared.args.device_draft] if shared.args.ctx_size_draft > 0: cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)] if shared.args.spec_type != 'none': cmd += ["--spec-type", shared.args.spec_type] cmd += ["--draft-max", str(shared.args.draft_max)] cmd += ["--spec-ngram-size-n", str(shared.args.spec_ngram_size_n)] cmd += ["--spec-ngram-size-m", str(shared.args.spec_ngram_size_m)] cmd += ["--spec-ngram-min-hits", str(shared.args.spec_ngram_min_hits)] cmd += ["--parallel", str(shared.args.parallel)] if shared.args.streaming_llm: cmd += ["--cache-reuse", "1"] cmd += ["--swa-full"] if shared.args.extra_flags: # Clean up the input extra_flags = shared.args.extra_flags.strip() if 
extra_flags.startswith('"') and extra_flags.endswith('"'): extra_flags = extra_flags[1:-1].strip() elif extra_flags.startswith("'") and extra_flags.endswith("'"): extra_flags = extra_flags[1:-1].strip() for flag_item in extra_flags.split(','): flag_item = flag_item.strip() if '=' in flag_item: flag, value = flag_item.split('=', 1) flag = flag.strip() value = value.strip() if len(flag) <= 3: cmd += [f"-{flag}", value] else: cmd += [f"--{flag}", value] else: if len(flag_item) <= 3: cmd.append(f"-{flag_item}") else: cmd.append(f"--{flag_item}") env = os.environ.copy() if os.name == 'posix': current_path = env.get('LD_LIBRARY_PATH', '') if current_path: env['LD_LIBRARY_PATH'] = f"{current_path}:{os.path.dirname(self.server_path)}" else: env['LD_LIBRARY_PATH'] = os.path.dirname(self.server_path) if shared.args.verbose: logger.info("llama-server command-line flags:") print(' '.join(str(item) for item in cmd[1:])) print() gpu_layers_str = "auto" if shared.args.gpu_layers < 0 else str(shared.args.gpu_layers) ctx_size_str = "auto" if shared.args.ctx_size == 0 and shared.args.gpu_layers < 0 else str(shared.args.ctx_size or 8192) logger.info(f"Using gpu_layers={gpu_layers_str} | ctx_size={ctx_size_str} | cache_type={cache_type}") # Start the server with pipes for output self.process = subprocess.Popen( cmd, stderr=subprocess.PIPE, bufsize=0, env=env ) threading.Thread(target=filter_stderr_with_progress, args=(self.process.stderr,), daemon=True).start() # Wait for server to be healthy health_url = f"http://127.0.0.1:{self.port}/health" while True: # Check if process is still alive if self.process.poll() is not None: # Process has terminated exit_code = self.process.poll() raise RuntimeError(f"Server process terminated unexpectedly with exit code: {exit_code}") try: response = self.session.get(health_url) if response.status_code == 200: break except Exception: pass time.sleep(1) # Server is now healthy, get model info self._get_vocabulary_size() self._get_bos_token() return 
self.port def __enter__(self): """Support for context manager.""" return self def __exit__(self, exc_type, exc_val, exc_tb): """Support for context manager.""" self.stop() def __del__(self): """Cleanup when the object is deleted.""" self.stop() def stop(self): """Stop the server process.""" if self.process: self.process.terminate() try: self.process.wait(timeout=5) except subprocess.TimeoutExpired: self.process.kill() self.process.wait(timeout=5) self.process = None def filter_stderr_with_progress(process_stderr): """ Reads stderr lines from a process, filters out noise, and displays progress updates inline (overwriting the same line) until completion. """ progress_re = re.compile(r'slot update_slots: id.*progress = (\d+\.\d+)') ansi_re = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]') log_prefix_re = re.compile(r'^[IWED] ') last_was_progress = False try: # Read in binary mode and decode manually buffer = b"" while True: # Read chunks aggressively to prevent buffer overflow chunk = process_stderr.read(4096) if not chunk: break buffer += chunk # Process complete lines while b'\n' in buffer: line_bytes, buffer = buffer.split(b'\n', 1) try: line = line_bytes.decode('utf-8', errors='replace').strip('\r\n') line = log_prefix_re.sub('', ansi_re.sub('', line)) if line: # Process non-empty lines match = progress_re.search(line) if match: progress = float(match.group(1)) # Extract just the part from "prompt processing" onwards prompt_processing_idx = line.find('prompt processing') if prompt_processing_idx != -1: display_line = line[prompt_processing_idx:] else: display_line = line # fallback to full line # choose carriage return for in-progress or newline at completion end_char = '\r' if progress < 1.0 else '\n' print(display_line, end=end_char, file=sys.stderr, flush=True) last_was_progress = (progress < 1.0) # skip noise lines elif not (line.startswith(('srv ', 'slot ')) or 'log_server_r: request: GET /health' in line or 'No parser definition detected' in line): # if we were in 
progress, finish that line first if last_was_progress: print(file=sys.stderr) print(line, file=sys.stderr, flush=True) last_was_progress = False except Exception: continue except (ValueError, IOError): pass finally: try: process_stderr.close() except Exception: pass ================================================ FILE: modules/loaders.py ================================================ import functools from collections import OrderedDict loaders_and_params = OrderedDict({ 'llama.cpp': [ 'gpu_layers', 'fit_target', 'cpu_moe', 'threads', 'threads_batch', 'batch_size', 'ubatch_size', 'ctx_size', 'cache_type', 'tensor_split', 'extra_flags', 'streaming_llm', 'row_split', 'no_kv_offload', 'no_mmap', 'mlock', 'numa', 'parallel', 'model_draft', 'draft_max', 'gpu_layers_draft', 'device_draft', 'ctx_size_draft', 'ngram_header', 'spec_type', 'spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits', 'speculative_decoding_accordion', 'mmproj', 'mmproj_accordion', 'vram_info', ], 'Transformers': [ 'gpu_split', 'cpu_memory', 'compute_dtype', 'quant_type', 'load_in_8bit', 'load_in_4bit', 'attn_implementation', 'cpu', 'disk', 'use_double_quant', 'bf16', 'no_use_fast', ], 'ExLlamav3_HF': [ 'ctx_size', 'cache_type', 'gpu_split', 'cfg_cache', 'no_use_fast', 'enable_tp', 'tp_backend', ], 'ExLlamav3': [ 'ctx_size', 'cache_type', 'gpu_split', 'model_draft', 'draft_max', 'speculative_decoding_accordion', 'enable_tp', 'tp_backend', ], 'TensorRT-LLM': [ 'ctx_size', 'tensorrt_llm_info', ] }) def transformers_samplers(): return { 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'top_n_sigma', 'adaptive_target', 'adaptive_decay', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 
'repetition_penalty_range',
        'penalty_alpha',
        'guidance_scale',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'prompt_lookup_num_tokens',
        'do_sample',
        'dynamic_temperature',
        'temperature_last',
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'reasoning_effort',
        'skip_special_tokens',
        'static_cache',
        'seed',
        'sampler_priority',
        'custom_token_bans',
        'negative_prompt',
        'dry_sequence_breakers',
        'grammar_string',
        'grammar_file_row',
    }


# Generation parameters supported by each loader. blacklist_samplers() hides
# every parameter that is not listed under the selected loader.
loaders_samplers = {
    'Transformers': transformers_samplers(),
    'ExLlamav3_HF': {
        'temperature',
        'dynatemp_low',
        'dynatemp_high',
        'dynatemp_exponent',
        'smoothing_factor',
        'smoothing_curve',
        'min_p',
        'top_p',
        'top_k',
        'typical_p',
        'xtc_threshold',
        'xtc_probability',
        'epsilon_cutoff',
        'eta_cutoff',
        'tfs',
        'top_a',
        'top_n_sigma',
        'adaptive_target',
        'adaptive_decay',
        'dry_multiplier',
        'dry_allowed_length',
        'dry_base',
        'repetition_penalty',
        'frequency_penalty',
        'presence_penalty',
        'encoder_repetition_penalty',
        'no_repeat_ngram_size',
        'repetition_penalty_range',
        'guidance_scale',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'do_sample',
        'dynamic_temperature',
        'temperature_last',
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'reasoning_effort',
        'skip_special_tokens',
        'seed',
        'sampler_priority',
        'custom_token_bans',
        'negative_prompt',
        'dry_sequence_breakers',
        'grammar_string',
        'grammar_file_row',
    },
    'ExLlamav3': {
        'temperature',
        'min_p',
        'top_p',
        'top_k',
        'adaptive_target',
        'adaptive_decay',
        'repetition_penalty',
        'frequency_penalty',
        'presence_penalty',
        'repetition_penalty_range',
        'temperature_last',
        'sampler_priority',
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'reasoning_effort',
        'seed',
        'skip_special_tokens',
    },
    'llama.cpp': {
        'temperature',
        'dynatemp_low',
        'dynatemp_high',
        'dynatemp_exponent',
        'min_p',
        'top_p',
        'top_k',
        'typical_p',
        'xtc_threshold',
        'xtc_probability',
        'top_n_sigma',
        'adaptive_target',
        'adaptive_decay',
        'dry_multiplier',
        'dry_allowed_length',
        'dry_base',
        'repetition_penalty',
        'frequency_penalty',
        'presence_penalty',
        'repetition_penalty_range',
        'mirostat_mode',
        'mirostat_tau',
        'mirostat_eta',
        'dynamic_temperature',
        'temperature_last',
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'enable_thinking',
        'reasoning_effort',
        'seed',
        'sampler_priority',
        'custom_token_bans',
        'dry_sequence_breakers',
        'grammar_string',
        'grammar_file_row',
    },
    'TensorRT-LLM': {
        'temperature',
        'top_p',
        'top_k',
        'min_p',
        'repetition_penalty',
        'frequency_penalty',
        'presence_penalty',
        'no_repeat_ngram_size',
        'auto_max_new_tokens',
        'ban_eos_token',
        'add_bos_token',
        'skip_special_tokens',
        'seed',
    }
}


@functools.cache
def list_all_samplers():
    """Return the sorted union of every loader's sampler names (computed once, then cached)."""
    all_samplers = set()
    for k in loaders_samplers:
        for sampler in loaders_samplers[k]:
            all_samplers.add(sampler)

    return sorted(all_samplers)


def blacklist_samplers(loader, dynamic_temperature):
    """
    Build Gradio visibility updates for the sampler controls.

    Args:
        loader: A key of loaders_samplers, or 'All' to show every sampler.
        dynamic_temperature: Visibility used for the 'dynatemp*' samplers that are shown.

    Returns:
        One gr.update per name in list_all_samplers(), in that order.
    """
    import gradio as gr
    all_samplers = list_all_samplers()
    output = []

    for sampler in all_samplers:
        if loader == 'All' or sampler in loaders_samplers[loader]:
            # dynatemp_* controls follow the dynamic_temperature toggle.
            if sampler.startswith('dynatemp'):
                output.append(gr.update(visible=dynamic_temperature))
            else:
                output.append(gr.update(visible=True))
        else:
            output.append(gr.update(visible=False))

    return output


@functools.cache
def get_all_params():
    """Return the sorted union of all loaders' UI parameter names (computed once, then cached)."""
    all_params = set()
    for k in loaders_and_params:
        for el in loaders_and_params[k]:
            all_params.add(el)

    return sorted(all_params)


def list_model_elements():
    # Names of the model-loading settings elements, including 'filter_by_loader' and 'loader'.
    # NOTE(review): presumably the UI components saved/restored with model settings — confirm against callers.
    return [
        'filter_by_loader',
        'loader',
        'cpu_memory',
        'gpu_layers',
        'fit_target',
        'cpu_moe',
        'threads',
        'threads_batch',
        'batch_size',
        'ubatch_size',
        'ctx_size',
        'cache_type',
        'tensor_split',
        'extra_flags',
        'streaming_llm',
        'gpu_split',
        'compute_dtype',
        'quant_type',
        'load_in_8bit',
        'load_in_4bit',
        'attn_implementation',
        'cpu',
        'disk',
        'row_split',
        'no_kv_offload',
        'no_mmap',
        'mlock',
        'numa',
        'parallel',
        'use_double_quant',
        'bf16',
        'enable_tp',
        'tp_backend',
        'cfg_cache',
        'no_use_fast',
        'model_draft',
        'draft_max',
        'gpu_layers_draft',
        'device_draft',
'ctx_size_draft', 'spec_type', 'spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits', 'mmproj', ] def make_loader_params_visible(loader): import gradio as gr params = [] all_params = get_all_params() if loader in loaders_and_params: params = loaders_and_params[loader] return [gr.update(visible=True) if k in params else gr.update(visible=False) for k in all_params] ================================================ FILE: modules/logging_colors.py ================================================ import logging logger = logging.getLogger('text-generation-webui') def setup_logging(): ''' Copied from: https://github.com/vladmandic/automatic All credits to vladmandic. ''' class RingBuffer(logging.StreamHandler): def __init__(self, capacity): super().__init__() self.capacity = capacity self.buffer = [] self.formatter = logging.Formatter('{ "asctime":"%(asctime)s", "created":%(created)f, "facility":"%(name)s", "pid":%(process)d, "tid":%(thread)d, "level":"%(levelname)s", "module":"%(module)s", "func":"%(funcName)s", "msg":"%(message)s" }') def emit(self, record): msg = self.format(record) # self.buffer.append(json.loads(msg)) self.buffer.append(msg) if len(self.buffer) > self.capacity: self.buffer.pop(0) def get(self): return self.buffer from rich.console import Console from rich.logging import RichHandler from rich.pretty import install as pretty_install from rich.theme import Theme from rich.traceback import install as traceback_install level = logging.DEBUG logger.setLevel(logging.DEBUG) # log to file is always at level debug for facility `sd` console = Console(log_time=True, log_time_format='%H:%M:%S-%f', theme=Theme({ "traceback.border": "black", "traceback.border.syntax_error": "black", "inspect.value.border": "black", })) logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(name)s | %(levelname)s | %(module)s | %(message)s', handlers=[logging.NullHandler()]) # redirect default logger to null pretty_install(console=console) 
traceback_install(console=console, extra_lines=1, max_frames=10, width=console.width, word_wrap=False, indent_guides=False, suppress=[]) while logger.hasHandlers() and len(logger.handlers) > 0: logger.removeHandler(logger.handlers[0]) # handlers rh = RichHandler(show_time=True, omit_repeated_times=False, show_level=True, show_path=False, markup=False, rich_tracebacks=True, log_time_format='%H:%M:%S-%f', level=level, console=console) rh.setLevel(level) logger.addHandler(rh) rb = RingBuffer(100) # 100 entries default in log ring buffer rb.setLevel(level) logger.addHandler(rb) logger.buffer = rb.buffer # overrides logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) logging.getLogger("diffusers").setLevel(logging.ERROR) logging.getLogger("torch").setLevel(logging.ERROR) logging.getLogger("lycoris").handlers = logger.handlers setup_logging() ================================================ FILE: modules/logits.py ================================================ import time import traceback import numpy as np from modules import models, shared from modules.logging_colors import logger from modules.models import load_model from modules.text_generation import generate_reply from modules.utils import check_model_loaded global_scores = None def get_next_logits(*args, **kwargs): if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']: shared.model, shared.tokenizer = load_model(shared.model_name) needs_lock = not args[2] # use_samplers if needs_lock: shared.generation_lock.acquire() try: result = _get_next_logits(*args, **kwargs) except Exception: traceback.print_exc() result = None if needs_lock: models.last_generation_time = time.time() shared.generation_lock.release() return result def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False): model_is_loaded, error_message = check_model_loaded() if not model_is_loaded: return error_message, 
previous

    # llama.cpp case
    if shared.model.__class__.__name__ == 'LlamaServer':
        logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers)

        if return_dict:
            output = {}
            for entry in logprobs:
                # repr() makes whitespace/control characters visible; drop the
                # surrounding single quotes that repr() adds.
                token = repr(entry['token'])
                if len(token) > 2 and token.startswith("'") and token.endswith("'"):
                    token = token[1:-1]

                # Post-sampling entries carry 'prob'; raw ones carry 'logprob'.
                prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
                output[token] = prob

            return output
        else:
            output = ''
            for entry in logprobs:
                token = repr(entry['token'])
                if len(token) > 2 and token.startswith("'") and token.endswith("'"):
                    token = token[1:-1]

                prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
                output += f"{prob:.5f} - {token}\n"

            return output, previous

    # All other model types
    else:
        import torch

        from modules import sampler_hijack
        from modules.torch_utils import get_device

        is_non_hf_exllamav3 = shared.model.__class__.__name__ == 'Exllamav3Model'

        # The raw-logits path below never reads `state`, so a minimal dict suffices.
        if not use_samplers:
            state = {'stream': True}

        if use_samplers:
            # Generate exactly one token (this mutates the caller's state dict),
            # then read the last scores recorded by sampler_hijack.
            state['max_new_tokens'] = 1
            state['auto_max_new_tokens'] = False
            state.setdefault('stream', True)
            for _ in generate_reply(prompt, state):
                pass

            scores = sampler_hijack.global_scores[-1]
        else:
            if is_non_hf_exllamav3:
                device = get_device()
                tokens = shared.tokenizer.encode(prompt)
                if device:
                    tokens = tokens.to(device)

                scores = shared.model.get_logits(tokens)[-1][-1]
            else:
                device = get_device()
                tokens = shared.tokenizer.encode(prompt, return_tensors='pt')
                if device:
                    tokens = tokens.to(device)

                output = shared.model(input_ids=tokens)
                scores = output['logits'][-1][-1]

        probs = torch.softmax(scores.detach(), dim=-1, dtype=torch.float)
        topk_values, topk_indices = torch.topk(probs, k=top_logits, largest=True, sorted=True)
        if hasattr(shared.tokenizer, 'convert_ids_to_tokens'):
            tokens = [shared.tokenizer.convert_ids_to_tokens(int(i)) for i in topk_indices]
        else:
            tokens = [shared.tokenizer.decode(i) for i in topk_indices]

        if return_dict:
            topk_values = [float(i) for i in topk_values]
            output = {}
            for row in list(zip(topk_values, tokens)):
                key = row[1]
                # Some tokenizers return bytes; fall back to latin-1 when not valid UTF-8.
                if isinstance(key, bytes):
                    try:
                        key = key.decode()
                    except Exception:
                        key = key.decode('latin')

                output[key] = row[0]

            return output
        else:
            topk_values = [f"{float(i):.5f}" for i in topk_values]
            output = ''
            for row in list(zip(topk_values, tokens)):
                output += f"{row[0]} - {repr(row[1])}\n"

            return output, previous
================================================
FILE: modules/metadata_gguf.py
================================================
import struct
from enum import IntEnum


# GGUF metadata value type codes.
class GGUFValueType(IntEnum):
    UINT8 = 0
    INT8 = 1
    UINT16 = 2
    INT16 = 3
    UINT32 = 4
    INT32 = 5
    FLOAT32 = 6
    BOOL = 7
    STRING = 8
    ARRAY = 9
    UINT64 = 10
    INT64 = 11
    FLOAT64 = 12


_simple_value_packing = {
    # NOTE(review): the extracted text is corrupted from here: apparently everything between a '<'
    # (inside this dict's first format string) and a later '>' (in modules/models.py) was stripped.
    # The run below is kept verbatim; restore it from the real files rather than editing it here.
    GGUFValueType.UINT8: " 0: shared.settings['truncation_length'] = shared.args.ctx_size elif loader == 'llama.cpp' and hasattr(model, 'n_ctx') and model.n_ctx: shared.settings['truncation_length'] = model.n_ctx shared.is_multimodal = False if loader.lower() in ('exllamav3', 'llama.cpp') and hasattr(model, 'is_multimodal'): shared.is_multimodal = model.is_multimodal() logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.") logger.info(f"LOADER: \"{loader}\"") logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}") logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"") return model, tokenizer


def llama_cpp_server_loader(model_name):
    """
    Start a llama.cpp server for `model_name`.

    A directory is resolved to its alphabetically first *.gguf file.

    Returns:
        (server, server) on success, or (None, None) after logging an error
        (no .gguf file found, or LlamaServer raised).
    """
    from modules.llama_cpp_server import LlamaServer

    path = resolve_model_path(model_name)

    if path.is_file():
        model_file = path
    else:
        gguf_files = sorted(path.glob('*.gguf'))
        if not gguf_files:
            logger.error(f"No .gguf models found in the directory: {path}")
            return None, None

        model_file = gguf_files[0]

    try:
        model = LlamaServer(model_file)
        return model, model
    except Exception as e:
        logger.error(f"Error loading the model with llama.cpp: {str(e)}")
        return None, None


def transformers_loader(model_name):
    # Imported lazily so the transformers stack is only loaded when this loader is used.
    from modules.transformers_loader import load_model_HF
    return
load_model_HF(model_name) def ExLlamav3_HF_loader(model_name): from modules.exllamav3_hf import Exllamav3HF return Exllamav3HF.from_pretrained(model_name) def ExLlamav3_loader(model_name): from modules.exllamav3 import Exllamav3Model model, tokenizer = Exllamav3Model.from_pretrained(model_name) return model, tokenizer def TensorRT_LLM_loader(model_name): try: from modules.tensorrt_llm import TensorRTLLMModel except ModuleNotFoundError: raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.") model = TensorRTLLMModel.from_pretrained(model_name) return model, model.tokenizer def unload_model(keep_model_name=False): if shared.model is None: return model_class_name = shared.model.__class__.__name__ is_llamacpp = (model_class_name == 'LlamaServer') if model_class_name in ['Exllamav3Model', 'Exllamav3HF', 'TensorRTLLMModel']: shared.model.unload() elif model_class_name == 'LlamaServer': shared.model.stop() shared.model = shared.tokenizer = None shared.lora_names = [] shared.model_dirty_from_training = False if not is_llamacpp: from modules.torch_utils import clear_torch_cache clear_torch_cache() if not keep_model_name: shared.model_name = 'None' def reload_model(): unload_model() shared.model, shared.tokenizer = load_model(shared.model_name) def unload_model_if_idle(): global last_generation_time logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.") while True: shared.generation_lock.acquire() try: if time.time() - last_generation_time > shared.args.idle_timeout * 60: if shared.model is not None: logger.info("Unloading the model for inactivity.") unload_model(keep_model_name=True) finally: shared.generation_lock.release() time.sleep(60) ================================================ FILE: modules/models_settings.py ================================================ import functools import json import re from math 
import floor from pathlib import Path import yaml from modules import loaders, metadata_gguf, shared from modules.logging_colors import logger from modules.utils import resolve_model_path def get_fallback_settings(): return { 'bf16': False, 'ctx_size': 8192, 'truncation_length': shared.settings['truncation_length'], 'truncation_length_info': shared.settings['truncation_length'], 'skip_special_tokens': shared.settings['skip_special_tokens'], } def get_model_metadata(model): model_path = resolve_model_path(model) model_settings = {} # Get settings from user_data/models/config.yaml and user_data/models/config-user.yaml settings = shared.model_config for pat in settings: if re.match(pat.lower(), Path(model).name.lower()): for k in settings[pat]: model_settings[k] = settings[pat][k] path = model_path / 'config.json' if path.exists(): hf_metadata = json.loads(open(path, 'r', encoding='utf-8').read()) else: hf_metadata = None if 'loader' not in model_settings: quant_method = None if hf_metadata is None else hf_metadata.get("quantization_config", {}).get("quant_method", None) model_settings['loader'] = infer_loader( model, model_settings, hf_quant_method=quant_method ) # GGUF metadata if model_settings['loader'] == 'llama.cpp': path = model_path if path.is_file(): model_file = path else: gguf_files = list(path.glob('*.gguf')) if not gguf_files: error_msg = f"No .gguf models found in directory: {path}" logger.error(error_msg) raise FileNotFoundError(error_msg) model_file = gguf_files[0] metadata = load_gguf_metadata_with_cache(model_file) for k in metadata: if k.endswith('.context_length'): model_settings['ctx_size'] = 0 model_settings['truncation_length_info'] = metadata[k] elif k.endswith('.block_count'): model_settings['gpu_layers'] = -1 model_settings['max_gpu_layers'] = metadata[k] + 1 if 'tokenizer.chat_template' in metadata: template = metadata['tokenizer.chat_template'] if 'tokenizer.ggml.eos_token_id' in metadata: eos_token = 
metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']] else: eos_token = "" if 'tokenizer.ggml.bos_token_id' in metadata: bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']] else: bos_token = "" shared.bos_token = bos_token shared.eos_token = eos_token template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL) template = re.sub(r'raise_exception\([^)]*\)', "''", template) model_settings['instruction_template'] = 'Custom (obtained from model metadata)' model_settings['instruction_template_str'] = template else: # Transformers metadata if hf_metadata is not None: metadata = json.loads(open(path, 'r', encoding='utf-8').read()) if 'pretrained_config' in metadata: metadata = metadata['pretrained_config'] for k in ['max_position_embeddings', 'model_max_length', 'max_seq_len']: if k in metadata: value = metadata[k] elif k in metadata.get('text_config', {}): value = metadata['text_config'][k] else: continue model_settings['truncation_length'] = value model_settings['truncation_length_info'] = value model_settings['ctx_size'] = min(value, 8192) break if 'torch_dtype' in metadata and metadata['torch_dtype'] == 'bfloat16': model_settings['bf16'] = True # Try to find the Jinja instruct template path = model_path / 'tokenizer_config.json' template = None # 1. Prioritize reading from chat_template.jinja if it exists jinja_path = model_path / 'chat_template.jinja' if jinja_path.exists(): with open(jinja_path, 'r', encoding='utf-8') as f: template = f.read() # 2. If no .jinja file, try chat_template.json if template is None: json_template_path = model_path / 'chat_template.json' if json_template_path.exists(): with open(json_template_path, 'r', encoding='utf-8') as f: json_data = json.load(f) if 'chat_template' in json_data: template = json_data['chat_template'] # 3. 
Fall back to tokenizer_config.json metadata if path.exists(): metadata = json.loads(open(path, 'r', encoding='utf-8').read()) # Only read from metadata if we haven't already loaded from .jinja or .json if template is None and 'chat_template' in metadata: template = metadata['chat_template'] if isinstance(template, list): template = template[0]['template'] # 4. If a template was found from any source, process it if template: shared.bos_token = '' shared.eos_token = '' for k in ['eos_token', 'bos_token']: if k in metadata: value = metadata[k] if isinstance(value, dict): value = value['content'] setattr(shared, k, value) template = re.sub(r"\{\{-?\s*raise_exception\(.*?\)\s*-?\}\}", "", template, flags=re.DOTALL) template = re.sub(r'raise_exception\([^)]*\)', "''", template) model_settings['instruction_template'] = 'Custom (obtained from model metadata)' model_settings['instruction_template_str'] = template if 'instruction_template' not in model_settings: model_settings['instruction_template'] = 'Alpaca' # Apply user settings from user_data/models/config-user.yaml settings = shared.user_config for pat in settings: if re.match(pat.lower(), Path(model).name.lower()): for k in settings[pat]: new_k = k if k == 'n_gpu_layers': new_k = 'gpu_layers' model_settings[new_k] = settings[pat][k] # Load instruction template if defined by name rather than by value if model_settings['instruction_template'] != 'Custom (obtained from model metadata)': model_settings['instruction_template_str'] = load_instruction_template(model_settings['instruction_template']) return model_settings def infer_loader(model_name, model_settings, hf_quant_method=None): path_to_model = resolve_model_path(model_name) if not path_to_model.exists(): loader = None elif shared.args.portable: loader = 'llama.cpp' elif len(list(path_to_model.glob('*.gguf'))) > 0: loader = 'llama.cpp' elif re.match(r'.*\.gguf', model_name.lower()): loader = 'llama.cpp' elif hf_quant_method == 'exl3': loader = 'ExLlamav3' elif 
re.match(r'.*exl3', model_name.lower()): loader = 'ExLlamav3' else: loader = 'Transformers' return loader def update_model_parameters(state, initial=False): ''' UI: update the command-line arguments based on the interface values ''' elements = loaders.list_model_elements() # the names of the parameters for i, element in enumerate(elements): if element not in state: continue value = state[element] if initial and element in shared.provided_arguments: continue if element == 'cpu_memory' and value == 0: value = vars(shared.args_defaults)[element] setattr(shared.args, element, value) def apply_model_settings_to_state(model, state): ''' UI: update the state variable with the model settings ''' import gradio as gr model_settings = get_model_metadata(model) if 'loader' in model_settings: loader = model_settings.pop('loader') if not (loader == 'ExLlamav3_HF' and state['loader'] == 'ExLlamav3'): state['loader'] = loader for k in model_settings: if k in state and k != 'gpu_layers': # Skip gpu_layers, handle separately state[k] = model_settings[k] # Handle GPU layers and VRAM update for llama.cpp if state['loader'] == 'llama.cpp' and 'gpu_layers' in model_settings: gpu_layers = model_settings['gpu_layers'] # -1 (auto) by default, or user-saved value max_layers = model_settings.get('max_gpu_layers', 256) state['gpu_layers'] = gr.update(value=gpu_layers, maximum=max_layers) vram_info = update_gpu_layers_and_vram( state['loader'], model, gpu_layers, state['ctx_size'], state['cache_type'], ) state['vram_info'] = vram_info return state def save_model_settings(model, state): ''' Save the settings for this model to user_data/models/config-user.yaml ''' if model == 'None': yield ("Not saving the settings because no model is selected in the menu.") return user_config = shared.load_user_config() model_regex = Path(model).name + '$' # For exact matches if model_regex not in user_config: user_config[model_regex] = {} for k in loaders.list_model_elements(): if k == 'loader' or k in 
loaders.loaders_and_params[state['loader']]: user_config[model_regex][k] = state[k] shared.user_config = user_config output = yaml.dump(user_config, sort_keys=False) p = Path(f'{shared.args.model_dir}/config-user.yaml') with open(p, 'w') as f: f.write(output) yield (f"Settings for `{model}` saved to `{p}`.") def save_instruction_template(model, template): ''' Similar to the function above, but it saves only the instruction template. ''' if model == 'None': yield ("Not saving the template because no model is selected in the menu.") return user_config = shared.load_user_config() model_regex = Path(model).name + '$' # For exact matches if model_regex not in user_config: user_config[model_regex] = {} if template == 'None': user_config[model_regex].pop('instruction_template', None) else: user_config[model_regex]['instruction_template'] = template shared.user_config = user_config output = yaml.dump(user_config, sort_keys=False) p = Path(f'{shared.args.model_dir}/config-user.yaml') with open(p, 'w') as f: f.write(output) if template == 'None': yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.") else: yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.") @functools.lru_cache(maxsize=1) def load_gguf_metadata_with_cache(model_file): return metadata_gguf.load_metadata(model_file) def get_model_size_mb(model_file: Path) -> float: filename = model_file.name # Check for multipart pattern match = re.match(r'(.+)-\d+-of-\d+\.gguf$', filename) if match: # It's a multipart file, find all matching parts base_pattern = match.group(1) part_files = sorted(model_file.parent.glob(f'{base_pattern}-*-of-*.gguf')) total_size = sum(p.stat().st_size for p in part_files) else: # Single part total_size = model_file.stat().st_size return total_size / (1024 ** 2) # Return size in MB def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type): model_file = resolve_model_path(gguf_file) metadata = 
load_gguf_metadata_with_cache(model_file) size_in_mb = get_model_size_mb(model_file) # Extract values from metadata n_layers = None n_kv_heads = None n_attention_heads = None # Fallback for models without separate KV heads embedding_dim = None for key, value in metadata.items(): if key.endswith('.block_count'): n_layers = value elif key.endswith('.attention.head_count_kv'): n_kv_heads = max(value) if isinstance(value, list) else value elif key.endswith('.attention.head_count'): n_attention_heads = max(value) if isinstance(value, list) else value elif key.endswith('.embedding_length'): embedding_dim = value if n_kv_heads is None: n_kv_heads = n_attention_heads if gpu_layers > n_layers: gpu_layers = n_layers # Convert cache_type to numeric if cache_type == 'q4_0': cache_type = 4 elif cache_type == 'q8_0': cache_type = 8 else: cache_type = 16 # Derived features size_per_layer = size_in_mb / max(n_layers, 1e-6) kv_cache_factor = n_kv_heads * cache_type * ctx_size embedding_per_context = embedding_dim / ctx_size # Calculate VRAM using the model # Details: https://oobabooga.github.io/blog/posts/gguf-vram-formula/ vram = ( (size_per_layer - 17.99552795246051 + 3.148552680382576e-05 * kv_cache_factor) * (gpu_layers + max(0.9690636483914102, cache_type - (floor(50.77817218646521 * embedding_per_context) + 9.987899908205632))) + 1516.522943869404 ) return vram def update_gpu_layers_and_vram(loader, model, gpu_layers, ctx_size, cache_type): """ Compute the estimated VRAM usage for the given GPU layers and return an HTML string for the UI display. """ if loader != 'llama.cpp' or model in ["None", None] or not model.endswith(".gguf") or gpu_layers < 0 or ctx_size == 0: return f"
      Estimated VRAM to load the model: auto
      " vram_usage = estimate_vram(model, gpu_layers, ctx_size, cache_type) return f"
      Estimated VRAM to load the model: {vram_usage:.0f} MiB
      " def load_instruction_template(template): if template == 'None': return '' for filepath in [shared.user_data_dir / 'instruction-templates' / f'{template}.yaml', shared.user_data_dir / 'instruction-templates' / 'Alpaca.yaml']: if filepath.exists(): break else: return '' with open(filepath, 'r', encoding='utf-8') as f: file_contents = f.read() data = yaml.safe_load(file_contents) if 'instruction_template' in data: return data['instruction_template'] else: return _jinja_template_from_old_format(data) def _jinja_template_from_old_format(params, verbose=False): MASTER_TEMPLATE = """ {%- set ns = namespace(found=false) -%} {%- for message in messages -%} {%- if message['role'] == 'system' -%} {%- set ns.found = true -%} {%- endif -%} {%- endfor -%} {%- if not ns.found -%} {{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}} {%- endif %} {%- for message in messages %} {%- if message['role'] == 'system' -%} {{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}} {%- else -%} {%- if message['role'] == 'user' -%} {{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}} {%- else -%} {{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}} {%- endif -%} {%- endif -%} {%- endfor -%} {%- if add_generation_prompt -%} {{-'<|PRE-ASSISTANT-GENERATE|>'-}} {%- endif -%} """ if 'context' in params and '<|system-message|>' in params['context']: pre_system = params['context'].split('<|system-message|>')[0] post_system = params['context'].split('<|system-message|>')[1] else: pre_system = '' post_system = '' pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user']) post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1] pre_assistant = pre_assistant.replace('<|bot|>', params['bot']) post_assistant = params['turn_template'].split('<|bot-message|>')[1] def preprocess(string): 
return string.replace('\n', '\\n').replace('\'', '\\\'') pre_system = preprocess(pre_system) post_system = preprocess(post_system) pre_user = preprocess(pre_user) post_user = preprocess(post_user) pre_assistant = preprocess(pre_assistant) post_assistant = preprocess(post_assistant) if verbose: print( '\n', repr(pre_system) + '\n', repr(post_system) + '\n', repr(pre_user) + '\n', repr(post_user) + '\n', repr(pre_assistant) + '\n', repr(post_assistant) + '\n', ) result = MASTER_TEMPLATE if 'system_message' in params: result = result.replace('<|SYSTEM-MESSAGE|>', preprocess(params['system_message'])) else: result = result.replace('<|SYSTEM-MESSAGE|>', '') result = result.replace('<|PRE-SYSTEM|>', pre_system) result = result.replace('<|POST-SYSTEM|>', post_system) result = result.replace('<|PRE-USER|>', pre_user) result = result.replace('<|POST-USER|>', post_user) result = result.replace('<|PRE-ASSISTANT|>', pre_assistant) result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.rstrip(' ')) result = result.replace('<|POST-ASSISTANT|>', post_assistant) result = result.strip() return result ================================================ FILE: modules/paths.py ================================================ import sys from pathlib import Path def resolve_user_data_dir(): """ Resolve the user_data directory path. Order of precedence: 1. --user-data-dir CLI flag (pre-parsed from sys.argv before argparse) 2. In --portable mode, prefer ../user_data if it exists 3. 
Default: 'user_data' """ script_dir = Path(__file__).resolve().parent.parent # Check sys.argv for --user-data-dir before argparse runs for i, arg in enumerate(sys.argv): if arg == '--user-data-dir' and i + 1 < len(sys.argv): return Path(sys.argv[i + 1]) elif arg.startswith('--user-data-dir='): return Path(arg.split('=', 1)[1]) # In portable mode, prefer ../user_data if it exists is_portable = '--portable' in sys.argv if is_portable: parent_path = script_dir.parent / 'user_data' if parent_path.exists(): return parent_path return Path('user_data') ================================================ FILE: modules/presets.py ================================================ import functools import pprint from pathlib import Path import yaml from modules import shared from modules.loaders import loaders_samplers from modules.logging_colors import logger default_preset_values = { 'temperature': 1, 'dynatemp_low': 1, 'dynatemp_high': 1, 'dynatemp_exponent': 1, 'smoothing_factor': 0, 'smoothing_curve': 1, 'top_p': 1, 'top_k': 0, 'min_p': 0, 'top_n_sigma': 0, 'typical_p': 1, 'xtc_threshold': 0.1, 'xtc_probability': 0, 'epsilon_cutoff': 0, 'eta_cutoff': 0, 'tfs': 1, 'top_a': 0, 'adaptive_target': 0, 'adaptive_decay': 0.9, 'dry_multiplier': 0, 'dry_allowed_length': 2, 'dry_base': 1.75, 'repetition_penalty': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'encoder_repetition_penalty': 1, 'no_repeat_ngram_size': 0, 'repetition_penalty_range': 1024, 'penalty_alpha': 0, 'guidance_scale': 1, 'mirostat_mode': 0, 'mirostat_tau': 5, 'mirostat_eta': 0.1, 'do_sample': True, 'dynamic_temperature': False, 'temperature_last': False, 'sampler_priority': 'repetition_penalty\npresence_penalty\nfrequency_penalty\ndry\ntop_n_sigma\ntemperature\ndynamic_temperature\nquadratic_sampling\ntop_k\ntop_p\ntypical_p\nepsilon_cutoff\neta_cutoff\ntfs\ntop_a\nmin_p\nadaptive_p\nmirostat\nxtc\nencoder_repetition_penalty\nno_repeat_ngram', 'dry_sequence_breakers': '"\\n", ":", "\\"", "*"', } def 
default_preset(): result = dict(default_preset_values) if shared.args.portable: samplers = result['sampler_priority'].split('\n') samplers = [sampler for sampler in samplers if sampler in ["dry", "top_k", "top_p", "top_n_sigma", "min_p", "temperature", "xtc", "typical_p", "repetition_penalty"]] result['sampler_priority'] = '\n'.join(samplers) return result def presets_params(): return [k for k in default_preset()] def load_preset(name, verbose=False): generate_params = default_preset() if name not in ['None', None, '']: path = shared.user_data_dir / 'presets' / f'{name}.yaml' if path.exists(): with open(path, 'r') as infile: preset = yaml.safe_load(infile) for k in preset: generate_params[k] = preset[k] else: logger.error(f"The preset \"{name}\" does not exist under \"{path}\". Using the default parameters.") if verbose: logger.info(f"\"{name}\" preset:") pprint.PrettyPrinter(indent=4, width=1, sort_dicts=False).pprint(remove_defaults(generate_params)) return generate_params @functools.cache def load_preset_memoized(name): return load_preset(name) def load_preset_for_ui(name, state): generate_params = load_preset(name, verbose=True) state.update(generate_params) return state, *[generate_params[k] for k in presets_params()] def reset_preset_for_ui(name, state): """Reset current preset to its saved values from file""" generate_params = load_preset(name, verbose=True) state.update(generate_params) return state, *[generate_params[k] for k in presets_params()] def neutralize_samplers_for_ui(state): """Set all samplers to their default/neutral values""" generate_params = default_preset() state.update(generate_params) return state, *[generate_params[k] for k in presets_params()] def loader_contains(sampler): if sampler == 'dynamic_temperature' and 'dynatemp_low' in loaders_samplers[shared.args.loader]: return True else: return sampler in loaders_samplers[shared.args.loader] def remove_defaults(state): defaults = default_preset() data = {k: state[k] for k in 
presets_params()} for k in list(data.keys()): if data[k] == defaults[k]: del data[k] return data def generate_preset_yaml(state): data = remove_defaults(state) return yaml.dump(data, sort_keys=False) ================================================ FILE: modules/prompts.py ================================================ from pathlib import Path from modules import shared, utils from modules.text_generation import get_encoded_length def load_prompt(fname): if not fname: # Create new file new_name = utils.current_time() prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) initial_content = "In this story," prompt_path.write_text(initial_content, encoding='utf-8') # Update settings to point to new file shared.settings['prompt-notebook'] = new_name return initial_content file_path = shared.user_data_dir / 'logs' / 'notebook' / f'{fname}.txt' if file_path.exists(): with open(file_path, 'r', encoding='utf-8') as f: text = f.read() text = text.rstrip() return text else: return '' def count_tokens(text): try: tokens = get_encoded_length(text) return str(tokens) except Exception: return '0' ================================================ FILE: modules/reasoning.py ================================================ import html as html_module # Thinking block format definitions: (start_tag, end_tag, content_start_tag) # Use None for start_tag to match from beginning (end-only formats should be listed last) THINKING_FORMATS = [ ('', '', None), ('<|channel|>analysis<|message|>', '<|end|>', '<|channel|>final<|message|>'), ('<|channel|>commentary<|message|>', '<|end|>', '<|channel|>final<|message|>'), ('', '', None), ('<|think|>', '<|end|>', '<|content|>'), # Solar Open # ('Thinking Process:', '', None), # Qwen3.5 verbose thinking outside tags -- removed: too prone to false positives in streaming (None, '', None), # End-only variant (e.g., Qwen3-next) ] def extract_reasoning(text, html_escaped=False): 
"""Extract reasoning/thinking blocks from the beginning of a string. When html_escaped=True, tags are HTML-escaped before searching (for use on already-escaped UI strings). Returns (reasoning_content, final_content) where reasoning_content is None if no thinking block is found. """ if not text: return None, text esc = html_module.escape if html_escaped else lambda s: s for start_tag, end_tag, content_tag in THINKING_FORMATS: end_esc = esc(end_tag) content_esc = esc(content_tag) if content_tag else None if start_tag is None: # End-only format: require end tag, start from beginning end_pos = text.find(end_esc) if end_pos == -1: continue thought_start = 0 else: # Normal format: require start tag start_esc = esc(start_tag) start_pos = text.find(start_esc) if start_pos == -1: # During streaming, the start tag may be arriving partially. # If the text is a prefix of a start tag, return empty content # to prevent the partial tag from leaking. stripped = text.strip() if stripped and start_esc.startswith(stripped): return '', '' continue thought_start = start_pos + len(start_esc) end_pos = text.find(end_esc, thought_start) if end_pos == -1: # End tag missing - check if content tag can serve as fallback if content_esc: content_pos = text.find(content_esc, thought_start) if content_pos != -1: thought_end = content_pos content_start = content_pos + len(content_esc) else: thought_end = len(text) content_start = len(text) else: thought_end = len(text) content_start = len(text) else: thought_end = end_pos if content_esc: content_pos = text.find(content_esc, end_pos) if content_pos != -1: content_start = content_pos + len(content_esc) else: # Content tag expected but not yet present (e.g. partial # streaming) — suppress intermediate tags between end_tag # and content_tag so they don't leak as content. 
content_start = len(text) else: content_start = end_pos + len(end_esc) return text[thought_start:thought_end], text[content_start:] # Handle standalone GPT-OSS final channel marker without a preceding # analysis/commentary block (the model skipped thinking entirely). for marker in ['<|start|>assistant<|channel|>final<|message|>', '<|channel|>final<|message|>']: marker_esc = esc(marker) pos = text.find(marker_esc) if pos != -1: before = text[:pos].strip() after = text[pos + len(marker_esc):] return (before if before else None), after return None, text ================================================ FILE: modules/sampler_hijack.py ================================================ import json import math import pprint import random import torch import transformers from transformers.generation.logits_process import ( LogitNormalization, LogitsProcessor, LogitsProcessorList ) from modules import shared from modules.logging_colors import logger from modules.torch_utils import get_device original_init = transformers.GenerationConfig.__init__ original_get_logits_processor = transformers.GenerationMixin._get_logits_processor global_scores = None class TemperatureLogitsWarperCustom(LogitsProcessor): ''' A copy of the original Transformers temperature logits warper. ''' def __init__(self, temperature: float): if not isinstance(temperature, float) or not (temperature > 0): except_msg = ( f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token " "scores will be invalid." ) if isinstance(temperature, float) and temperature == 0.0: except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`." raise ValueError(except_msg) self.temperature = temperature def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores = scores / self.temperature return scores class DynamicTemperatureLogitsWarper(LogitsProcessor): ''' Dynamic temperature. 
''' def __init__(self, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float): self.dynatemp_low = dynatemp_low self.dynatemp_high = dynatemp_high self.dynatemp_exponent = dynatemp_exponent def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: min_temp = self.dynatemp_low max_temp = self.dynatemp_high exponent_val = self.dynatemp_exponent # Convert logits to probabilities probs = torch.softmax(scores, dim=-1) # Calculate entropy of the softmax probabilities entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum() # Guard against future possible division by zero entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0 # Any logits which are not -Infinity will be considered for calculating max entropy. num_valid_tokens = torch.sum(scores > -float('inf')).item() # Now, calculate the max entropy by using only the valid tokens' count max_entropy = math.log(num_valid_tokens) # Guard against future possible division by zero max_entropy = max_entropy if max_entropy > 0.0 else 1e-10 # Normalize the entropy normalized_entropy = entropy / max_entropy # Map the normalized entropy to the desired temperature range using the power function dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val)) # Apply the dynamically calculated temperature scaling scores = scores / dyn_temp # print("----------------------\nTemperature from generation_config:", self.temperature) # print("min_temp:", min_temp) # print("max_temp:", max_temp) # print("Entropy:", entropy.item()) # print("Max Possible Entropy considering valid tokens only:", max_entropy) # print("Normalized Entropy:", normalized_entropy.item()) # print("Dynamic Temperature (dyn_temp):", dyn_temp.item()) # print("----------------------") # max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability # max_prob_token = 
shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token # print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val) return scores class QuadraticSamplingLogitsWarper(LogitsProcessor): ''' Quadratic sampling with smoothing factor and smoothing curve parameters. ''' def __init__(self, smoothing_factor, smoothing_curve): self.smoothing_factor = smoothing_factor self.smoothing_curve = smoothing_curve def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # Compute necessary values max_logit = scores.max() diff = scores - max_logit k = (3 - self.smoothing_curve) / 2 s = (self.smoothing_curve - 1) / 2 # Apply transformation to non-negative infinity values transformed_logits = torch.where( scores != float('-inf'), -(k * self.smoothing_factor * diff**2) + (s * self.smoothing_factor * diff**3) + max_logit, scores ) return transformed_logits class TailFreeLogitsWarper(LogitsProcessor): def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): tfs = float(tfs) if tfs < 0 or tfs > 1.0: raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}") self.tfs = tfs self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: sorted_logits, sorted_indices = torch.sort(scores, descending=True) probs = sorted_logits.softmax(dim=-1) # Compute second derivative normalized CDF d2 = probs.diff().diff().abs() normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True) normalized_d2_cdf = normalized_d2.cumsum(dim=-1) # Remove tokens with CDF value above the threshold (token with 0 are kept) sorted_indices_to_remove = normalized_d2_cdf > self.tfs # Centre the distribution around the cutoff as in the original implementation of the algorithm sorted_indices_to_remove = torch.cat( ( 
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device), sorted_indices_to_remove, torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device), ), dim=-1, ) if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores class TopALogitsWarper(LogitsProcessor): def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): top_a = float(top_a) if top_a < 0 or top_a > 1.0: raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}") self.top_a = top_a self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: sorted_logits, sorted_indices = torch.sort(scores, descending=True) probs = sorted_logits.softmax(dim=-1) # Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept) probs_max = probs[..., 0, None] sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores class TopNSigmaLogitsWarper(LogitsProcessor): def __init__(self, n_sigma: float = 2.0, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): """ Initialize Top-nσ Sampling logits warper. 
Args: n_sigma: The threshold multiplier for standard deviation filter_value: Value to assign to filtered logits min_tokens_to_keep: Minimum number of tokens to keep """ if n_sigma < 0: raise ValueError(f"`n_sigma` must be a non-negative float, but is {n_sigma}") self.n_sigma = n_sigma self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # Calculate max of logits max_logit = torch.max(scores, dim=-1, keepdim=True)[0] # Calculate standard deviation only on finite values finite_mask = torch.isfinite(scores) finite_scores = scores.masked_fill(~finite_mask, 0.0) std_logit = torch.std(finite_scores, dim=-1, keepdim=True) # Create mask where tokens with logits >= max_logit - n_sigma * std_logit are kept threshold = max_logit - self.n_sigma * std_logit indices_to_remove = scores < threshold if self.min_tokens_to_keep > 1: # Keep at least min_tokens_to_keep tokens top_k_indices = torch.topk(scores, self.min_tokens_to_keep, dim=-1)[1] indices_to_remove.scatter_(-1, top_k_indices, False) # Apply mask by setting filtered tokens to filter_value scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores class AdaptivePLogitsWarper(LogitsProcessor): ''' Adaptive-p sampling. A stateful sampler that favors tokens near a target probability, using an EMA-based control loop to adapt over time. Matches the llama.cpp implementation from PR #17927. 
''' DISTRIBUTION_WIDTH = 0.3 PEAK_LOGIT_VALUE = 5.0 SHARPNESS = 10.0 INV_WIDTH = 1.0 / DISTRIBUTION_WIDTH def __init__(self, adaptive_target, adaptive_decay, filter_value=-float("Inf"), min_tokens_to_keep=1): self.target = adaptive_target self.decay = min(adaptive_decay, 0.99) self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep # Initialize EMA at equilibrium (as if target was already achieved) if self.decay < 1.0: self.weighted_sum = self.target / (1.0 - self.decay) self.total_weight = 1.0 / (1.0 - self.decay) else: self.weighted_sum = 0.0 self.total_weight = 0.0 def __call__(self, input_ids, scores): logits = scores[0] # Compute original probabilities (before transform) probs = torch.softmax(logits, dim=-1) # Compute adapted target using proportional control on the EMA if self.total_weight > 0: ema_avg = self.weighted_sum / self.total_weight else: ema_avg = self.target adapted_target = max(0.0, min(1.0, 2.0 * self.target - ema_avg)) # Adaptive probability transform: # quadratic near target for fine differentiation, transitioning # to linear decay in the tails for proper suppression after softmax dist = torch.abs((probs - adapted_target) * self.INV_WIDTH) new_logits = self.PEAK_LOGIT_VALUE - self.SHARPNESS * dist * dist / (1.0 + dist) # Preserve already-masked tokens (-inf logits from prior samplers) new_logits = torch.where(torch.isfinite(logits), new_logits, logits) # Softmax and sample from the transformed distribution new_probs = torch.softmax(new_logits, dim=-1) selected = torch.multinomial(new_probs, num_samples=1, replacement=True) # Update EMA with the original probability of the selected token original_prob = probs[selected[0]].item() self.weighted_sum = original_prob + self.decay * self.weighted_sum self.total_weight = 1.0 + self.decay * self.total_weight # Mask all tokens except the selected one indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool) indices_to_remove[selected[0]] = False indices_to_remove = 
indices_to_remove.unsqueeze(0) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores # Exclude Top Choices (XTC) class XTCLogitsWarper(LogitsProcessor): def __init__(self, threshold: float, probability: float, filter_value: float = -float("Inf")): self.threshold = threshold self.probability = probability self.filter_value = filter_value self.special_token_ids = [ shared.tokenizer.encode("\n")[-1], ] if shared.tokenizer.eos_token_id is not None: self.special_token_ids.append(shared.tokenizer.eos_token_id) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # `random` returns values in the half-open range [0, 1), so setting `probability` # to 0 means the sampler never takes action, while setting it to 1 means the sampler # always takes action. # # Note that while XTC is most intuitively described as "if multiple tokens meet # the threshold, then with probability...", reversing the two conditions is logically # equivalent, and improves performance because processing can immediately be stopped # if the random check fails. if random.random() >= self.probability: return scores sorted_logits, sorted_indices = torch.sort(scores, descending=True) probs = sorted_logits.softmax(dim=-1) sorted_indices_to_remove = torch.full_like(probs, False, dtype=torch.bool) # This operation sets exactly those indices to `True` for which the next index has # probability above the threshold. Since `probs` is sorted, those are the indices # of all tokens that meet the threshold, *except* the least probable one. 
sorted_indices_to_remove[..., :-1] = probs[..., 1:] >= self.threshold # Convert sorted_indices_to_remove to the original indices indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) # If newline or EOS tokens would be removed, return the original scores if indices_to_remove[:, self.special_token_ids].any(): return scores # Otherwise, remove tokens with the mask scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores class DRYLogitsProcessor(LogitsProcessor): def __init__(self, multiplier: float, base: float, allowed_length: int, sequence_breakers: set[int], _range: int): self.multiplier = multiplier self.base = base self.allowed_length = allowed_length self.sequence_breakers = sequence_breakers self._range = _range def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if self._range > 0: input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): # Use normal Python data types for improved performance input_ids = input_ids_row.tolist() last_token = input_ids[-1] if last_token in self.sequence_breakers: continue # Exclude the last token as it always matches. match_indices = [] for idx, val in enumerate(input_ids[:-1]): if val == last_token: match_indices.append(idx) # Stores the maximum matching sequence length # for each token immediately following the sequence in the input. match_lengths = {} for i in match_indices: next_token = input_ids[i + 1] if next_token in self.sequence_breakers: continue # We have already found that `last_token` matches at this index, # so the match is at least of length 1. match_length = 1 # Extend the match backwards (at most to 50 to prevent exponent overflow at penalty calculation) (this cap also improves performance on worst case) while match_length < 50: j = i - match_length if j < 0: # Start of input reached. 
break previous_token = input_ids[-(match_length + 1)] if input_ids[j] != previous_token: # Start of match reached. break if previous_token in self.sequence_breakers: # Sequence-breaking token reached. break match_length += 1 if next_token in match_lengths: match_lengths[next_token] = max(match_length, match_lengths[next_token]) else: match_lengths[next_token] = match_length # Apply penalties. for token, match_length in match_lengths.items(): if match_length >= self.allowed_length: penalty = self.multiplier * self.base ** (match_length - self.allowed_length) scores_row[token] -= penalty return scores class MirostatLogitsWarper(LogitsProcessor): def __init__(self, mirostat_mode: int, mirostat_tau: float, mirostat_eta: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): if mirostat_mode not in [2]: raise ValueError(f"`mirostat` has to be a an integer 2, but is {mirostat_mode}") self.mirostat_mode = mirostat_mode self.mirostat_eta = mirostat_eta self.mirostat_tau = mirostat_tau self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep self.mu = 2 * self.mirostat_tau self.e = 0 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: logits = scores[0] sorted_logits, sorted_indices = torch.sort(logits, descending=True) prob_original = torch.softmax(sorted_logits, dim=-1).tolist() # candidates # Truncate the words with surprise values greater than mu for i, candidate in enumerate(prob_original): if candidate > 0 and -math.log2(candidate) > self.mu: if (i == 0): sorted_logits = sorted_logits[:1] else: sorted_logits = sorted_logits[:i] break # Normalize the probabilities of the remaining words prob_topk = torch.softmax(sorted_logits, dim=0) prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True) device = get_device() if device: prob_topk = prob_topk.to(device) prev_i = prev_i.to(device) observed_surprise = -math.log2(prob_topk[prev_i]) self.e = observed_surprise - 
self.mirostat_tau # Update mu using the learning rate and error self.mu -= self.mirostat_eta * self.e sorted_indices_to_remove = torch.ones_like(scores[0], dtype=torch.bool) sorted_indices_to_remove[prev_i] = False indices_to_remove = sorted_indices_to_remove.unsqueeze(0).scatter(1, sorted_indices.unsqueeze(0), sorted_indices_to_remove.unsqueeze(0)) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores class SpyLogitsWarper(LogitsProcessor): def __init__(self): pass def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: global global_scores global_scores = scores return scores class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): def __init__(self, penalty: float, _range: int): if not (penalty > 0): raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}") self.penalty = penalty self._range = _range def apply_repetition_penalty(self, input_ids_row, scores_row): unique_ids = torch.unique(input_ids_row) score = torch.gather(scores_row, 0, unique_ids) # Apply multiplicative repetition penalty score = torch.where(score < 0, score * self.penalty, score / self.penalty) scores_row.scatter_(0, unique_ids, score) return scores_row def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): scores_row = self.apply_repetition_penalty(input_ids_row, scores_row) return scores class PresencePenaltyLogitsProcessor(LogitsProcessor): def __init__(self, presence_penalty: float, _range: int): self.presence_penalty = presence_penalty self._range = _range def apply_presence_penalty(self, input_ids_row, scores_row): unique_ids, counts = torch.unique(input_ids_row, return_counts=True) # Apply presence penalty raw_presence_penalty = (counts > 0).to(scores_row.dtype) presence_penalty = raw_presence_penalty * self.presence_penalty scores_row.scatter_add_(0, 
unique_ids, -presence_penalty) return scores_row def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): scores_row = self.apply_presence_penalty(input_ids_row, scores_row) return scores class FrequencyPenaltyLogitsProcessor(LogitsProcessor): def __init__(self, frequency_penalty: float, _range: int): self.frequency_penalty = frequency_penalty self._range = _range def apply_frequency_penalty(self, input_ids_row, scores_row): unique_ids, counts = torch.unique(input_ids_row, return_counts=True) # Apply frequency penalty raw_frequency_penalty = counts.to(scores_row.dtype) frequency_penalty = raw_frequency_penalty * self.frequency_penalty scores_row.scatter_add_(0, unique_ids, -frequency_penalty) return scores_row def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids[:, -self._range:] for input_ids_row, scores_row in zip(input_ids, scores): scores_row = self.apply_frequency_penalty(input_ids_row, scores_row) return scores def get_logits_processor_patch(self, **kwargs): generation_config = kwargs['generation_config'] # Parameter sanitization if isinstance(generation_config.temperature, int): generation_config.temperature = float(generation_config.temperature) # Must be float # Get the original warpers warpers = original_get_logits_processor(self, **kwargs) for i in range(len(warpers) - 1, -1, -1): # Replace temperature with our modified class. 
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper': warpers[i] = TemperatureLogitsWarperCustom( generation_config.temperature, ) # Stuff we don't need elif warpers[i].__class__.__name__ in ['RepetitionPenaltyLogitsProcessor']: del warpers[i] # Add custom warpers warpers_to_add = LogitsProcessorList() min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1 if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0: warpers_to_add.append( RepetitionPenaltyLogitsProcessorWithRange( penalty=generation_config.repetition_penalty, _range=generation_config.repetition_penalty_range ) ) if generation_config.presence_penalty is not None and generation_config.presence_penalty != 0.0: warpers_to_add.append( PresencePenaltyLogitsProcessor( presence_penalty=generation_config.presence_penalty, _range=generation_config.repetition_penalty_range ) ) if generation_config.frequency_penalty is not None and generation_config.frequency_penalty != 0.0: warpers_to_add.append( FrequencyPenaltyLogitsProcessor( frequency_penalty=generation_config.frequency_penalty, _range=generation_config.repetition_penalty_range ) ) if generation_config.dry_multiplier is not None and generation_config.dry_multiplier > 0.0: dry_sequence_breakers = generation_config.dry_sequence_breakers # Support both JSON array notation and comma-separated strings. if not dry_sequence_breakers.startswith("["): dry_sequence_breakers = "[" + dry_sequence_breakers + "]" sequence_breaker_strings = json.loads(dry_sequence_breakers) # Prefix with 'a' to get the correct encoding of the token at the end of a text. 
sequence_breakers = { shared.tokenizer.encode(f'a{s}')[-1] for s in sequence_breaker_strings } warpers.append( DRYLogitsProcessor( multiplier=generation_config.dry_multiplier, base=generation_config.dry_base, allowed_length=generation_config.dry_allowed_length, sequence_breakers=sequence_breakers, _range=generation_config.repetition_penalty_range, ) ) if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0: warpers_to_add.append( TailFreeLogitsWarper( tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0: warpers_to_add.append( TopALogitsWarper( top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.top_n_sigma is not None and generation_config.top_n_sigma > 0.0: warpers_to_add.append( TopNSigmaLogitsWarper( n_sigma=generation_config.top_n_sigma, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.adaptive_target is not None and generation_config.adaptive_target > 0.0: warpers_to_add.append( AdaptivePLogitsWarper( adaptive_target=generation_config.adaptive_target, adaptive_decay=generation_config.adaptive_decay, min_tokens_to_keep=min_tokens_to_keep ) ) if generation_config.xtc_probability is not None and generation_config.xtc_probability > 0: warpers_to_add.append( XTCLogitsWarper( threshold=generation_config.xtc_threshold, probability=generation_config.xtc_probability, ) ) if generation_config.dynamic_temperature: warpers_to_add.append( DynamicTemperatureLogitsWarper( dynatemp_low=generation_config.dynatemp_low, dynatemp_high=generation_config.dynatemp_high, dynatemp_exponent=generation_config.dynatemp_exponent, ) ) if generation_config.smoothing_factor > 0: warpers_to_add.append( QuadraticSamplingLogitsWarper( smoothing_factor=generation_config.smoothing_factor, smoothing_curve=generation_config.smoothing_curve ) ) if generation_config.mirostat_mode is not None and generation_config.mirostat_mode == 
2: warpers_to_add.append( MirostatLogitsWarper( mirostat_mode=generation_config.mirostat_mode, mirostat_eta=generation_config.mirostat_eta, mirostat_tau=generation_config.mirostat_tau, min_tokens_to_keep=min_tokens_to_keep ) ) if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization): normalize = warpers.pop(-1) else: normalize = None warpers += warpers_to_add # Sort the samplers. sampler_priority = generation_config.sampler_priority # Handle temperature_last if generation_config.temperature_last: for param_name in ['temperature', 'dynamic_temperature', 'quadratic_sampling']: if param_name in sampler_priority: index = sampler_priority.index(param_name) sampler_priority.append(sampler_priority.pop(index)) else: sampler_priority.append(param_name) class_name_to_nickname = { 'DynamicTemperatureLogitsWarper': 'dynamic_temperature', 'EpsilonLogitsWarper': 'epsilon_cutoff', 'EtaLogitsWarper': 'eta_cutoff', 'MinPLogitsWarper': 'min_p', 'MirostatLogitsWarper': 'mirostat', 'QuadraticSamplingLogitsWarper': 'quadratic_sampling', 'TailFreeLogitsWarper': 'tfs', 'TemperatureLogitsWarperCustom': 'temperature', 'TopALogitsWarper': 'top_a', 'TopNSigmaLogitsWarper': 'top_n_sigma', 'AdaptivePLogitsWarper': 'adaptive_p', 'TopKLogitsWarper': 'top_k', 'TopPLogitsWarper': 'top_p', 'TypicalLogitsWarper': 'typical_p', 'XTCLogitsWarper': 'xtc', 'RepetitionPenaltyLogitsProcessorWithRange': 'repetition_penalty', 'PresencePenaltyLogitsProcessor': 'presence_penalty', 'FrequencyPenaltyLogitsProcessor': 'frequency_penalty', 'DRYLogitsProcessor': 'dry', 'EncoderRepetitionPenaltyLogitsProcessor': 'encoder_repetition_penalty', 'NoRepeatNGramLogitsProcessor': 'no_repeat_ngram', } def custom_sort_key(obj): class_name = obj.__class__.__name__ # Return -1 if class_name is not mapped if class_name not in class_name_to_nickname or class_name_to_nickname[class_name] not in sampler_priority: return -1 return sampler_priority.index(class_name_to_nickname[class_name]) # Sort the list using the 
custom key function warpers = sorted(warpers, key=custom_sort_key) if shared.args.verbose: logger.info("WARPERS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint([x.__class__.__name__ for x in warpers]) print() if normalize is not None: warpers.append(normalize) warpers.append(SpyLogitsWarper()) warpers = LogitsProcessorList(warpers) return warpers def generation_config_init_patch(self, **kwargs): original_init(self, **kwargs) self.min_p = kwargs.pop("min_p", 0.0) self.dynamic_temperature = kwargs.pop("dynamic_temperature", False) self.dynatemp_low = kwargs.pop("dynatemp_low", 1) self.dynatemp_high = kwargs.pop("dynatemp_high", 1) self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1) self.smoothing_factor = kwargs.pop("smoothing_factor", 0.0) self.smoothing_curve = kwargs.pop("smoothing_curve", 1.0) self.tfs = kwargs.pop("tfs", 1.0) self.top_a = kwargs.pop("top_a", 0.0) self.top_n_sigma = kwargs.pop("top_n_sigma", 0.0) self.adaptive_target = kwargs.pop("adaptive_target", 0.0) self.adaptive_decay = kwargs.pop("adaptive_decay", 0.9) self.mirostat_mode = kwargs.pop("mirostat_mode", 0) self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_tau = kwargs.pop("mirostat_tau", 5) self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) self.presence_penalty = kwargs.pop("presence_penalty", 0) self.frequency_penalty = kwargs.pop("frequency_penalty", 0) self.dry_multiplier = kwargs.pop("dry_multiplier", 0.0) self.dry_base = kwargs.pop("dry_base", 1.75) self.dry_allowed_length = kwargs.pop("dry_allowed_length", 2) self.dry_sequence_breakers = kwargs.pop("dry_sequence_breakers", '"\\n", ":", "\\"", "*"') self.xtc_threshold = kwargs.pop("xtc_threshold", 0.1) self.xtc_probability = kwargs.pop("xtc_probability", 0) self.temperature_last = kwargs.pop("temperature_last", False) self.sampler_priority = kwargs.pop("sampler_priority", ['repetition_penalty', 'presence_penalty', 'frequency_penalty', 'dry', 'temperature', 'dynamic_temperature', 
'quadratic_sampling', 'top_n_sigma', 'top_k', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'min_p', 'adaptive_p', 'mirostat', 'xtc', 'encoder_repetition_penalty', 'no_repeat_ngram']) def hijack_samplers(): transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch transformers.GenerationConfig.__init__ = generation_config_init_patch ================================================ FILE: modules/sane_markdown_lists.py ================================================ # Code based on the Sane List Extension for Python-Markdown # ======================================= # Modify the behavior of Lists in Python-Markdown to act in a sane manner. # See https://Python-Markdown.github.io/extensions/sane_lists # for documentation. # Original code Copyright 2011 [Waylan Limberg](http://achinghead.com) # All changes Copyright 2011-2014 The Python Markdown Project # License: [BSD](https://opensource.org/licenses/bsd-license.php) """ Modify the behavior of Lists in Python-Markdown to act in a sane manner. """ from __future__ import annotations import re import xml.etree.ElementTree as etree from typing import TYPE_CHECKING from markdown import Extension from markdown.blockparser import BlockParser from markdown.blockprocessors import ( ListIndentProcessor, OListProcessor, ParagraphProcessor ) if TYPE_CHECKING: # pragma: no cover from markdown import blockparser # The min. number of added leading spaces needed to start a nested list MIN_NESTED_LIST_INDENT = 2 assert MIN_NESTED_LIST_INDENT > 1, "'MIN_NESTED_LIST_INDENT' must be > 1" class SaneListIndentProcessor(ListIndentProcessor): """ Process children of list items. 
Example * a list item process this part or this part """ def __init__(self, *args): super().__init__(*args) self.INDENT_RE = re.compile(r'^(([ ])+)') def test(self, parent: etree.Element, block: str) -> bool: return block.startswith(' ' * MIN_NESTED_LIST_INDENT) and \ not self.parser.state.isstate('detabbed') and \ (parent.tag in self.ITEM_TYPES or (len(parent) and parent[-1] is not None and (parent[-1].tag in self.LIST_TYPES))) def get_level(self, parent: etree.Element, block: str) -> tuple[int, etree.Element]: """ Get level of indentation based on list level. """ # Get indent level m = self.INDENT_RE.match(block) if m: indent_level = len(m.group(1)) / MIN_NESTED_LIST_INDENT else: indent_level = 0 if self.parser.state.isstate('list'): # We're in a tight-list - so we already are at correct parent. level = 1 else: # We're in a loose-list - so we need to find parent. level = 0 # Step through children of tree to find matching indent level. while indent_level > level: child = self.lastChild(parent) if child is not None and (child.tag in self.LIST_TYPES or child.tag in self.ITEM_TYPES): if child.tag in self.LIST_TYPES: level += 1 parent = child else: # No more child levels. If we're short of `indent_level`, # we have a code block. So we stop here. break return level, parent def detab(self, text: str, length: int | None = None) -> tuple[str, str]: """ Remove a tab from the front of each line of the given text. """ if length is None: length = MIN_NESTED_LIST_INDENT newtext = [] lines = text.split('\n') for line in lines: if line.startswith(' ' * length): newtext.append(line[length:]) elif not line.strip(): newtext.append('') else: break return '\n'.join(newtext), '\n'.join(lines[len(newtext):]) def looseDetab(self, text: str, level: int = 1) -> str: """ Remove indentation from front of lines but allowing dedented lines. 
""" lines = text.split('\n') for i in range(len(lines)): if lines[i].startswith(' ' * MIN_NESTED_LIST_INDENT * level): lines[i] = lines[i][MIN_NESTED_LIST_INDENT * level:] return '\n'.join(lines) class SaneOListProcessor(OListProcessor): """ Override `SIBLING_TAGS` to not include `ul` and set `LAZY_OL` to `False`. """ SIBLING_TAGS = ['ol'] """ Exclude `ul` from list of siblings. """ LAZY_OL = False """ Disable lazy list behavior. """ def __init__(self, parser: blockparser.BlockParser): super().__init__(parser) max_list_start_indent = self.tab_length # Detect an item (e.g., `1. item`) self.RE = re.compile(r'^[ ]{0,%d}[\*_]{0,2}\d+\.[ ]+(.*)' % max_list_start_indent) # Detect items on secondary lines. they can be of either list type. self.CHILD_RE = re.compile(r'^[ ]{0,%d}([\*_]{0,2})((\d+\.))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1)) # Detect indented (nested) items of either type self.INDENT_RE = re.compile(r'^[ ]{%d,%d}[\*_]{0,2}((\d+\.)|[*+-])[ ]+.*' % (MIN_NESTED_LIST_INDENT, self.tab_length * 2)) def run(self, parent: etree.Element, blocks: list[str]) -> None: # Check for multiple items in one block. items = self.get_items(blocks.pop(0)) sibling = self.lastChild(parent) if sibling is not None and sibling.tag in self.SIBLING_TAGS: # Previous block was a list item, so set that as parent lst = sibling # make sure previous item is in a `p` - if the item has text, # then it isn't in a `p` if lst[-1].text: # since it's possible there are other children for this # sibling, we can't just `SubElement` the `p`, we need to # insert it as the first item. p = etree.Element('p') p.text = lst[-1].text lst[-1].text = '' lst[-1].insert(0, p) # if the last item has a tail, then the tail needs to be put in a `p` # likely only when a header is not followed by a blank line lch = self.lastChild(lst[-1]) if lch is not None and lch.tail: p = etree.SubElement(lst[-1], 'p') p.text = lch.tail.lstrip() lch.tail = '' # parse first block differently as it gets wrapped in a `p`. 
li = etree.SubElement(lst, 'li') self.parser.state.set('looselist') firstitem = items.pop(0) self.parser.parseBlocks(li, [firstitem]) self.parser.state.reset() elif parent.tag in ['ol', 'ul']: # this catches the edge case of a multi-item indented list whose # first item is in a blank parent-list item: # * * subitem1 # * subitem2 # see also `ListIndentProcessor` lst = parent else: # This is a new list so create parent with appropriate tag. lst = etree.SubElement(parent, self.TAG) # Check if a custom start integer is set if not self.LAZY_OL and self.STARTSWITH != '1': lst.attrib['start'] = self.STARTSWITH self.parser.state.set('list') # Loop through items in block, recursively parsing each with the # appropriate parent. for item in items: if item.startswith(" " * MIN_NESTED_LIST_INDENT): # Item is indented. Parse with last item as parent self.parser.parseBlocks(lst[-1], [item]) else: # New item. Create `li` and parse with it as parent li = etree.SubElement(lst, 'li') self.parser.parseBlocks(li, [item]) self.parser.state.reset() def looseDetab(self, text: str, indent_length: int, level: int = 1) -> str: """ Remove indentation from front of lines but allowing dedented lines. """ lines = text.split('\n') for i in range(len(lines)): if lines[i].startswith(' ' * indent_length * level): lines[i] = lines[i][indent_length * level:] return '\n'.join(lines) def get_items(self, block: str) -> list[str]: """ Break a block into list items. 
""" # If first level of list is indented, remove that indentation if (indent_len := len(block) - len(block.lstrip())) > 0: block = self.looseDetab(block, indent_len) items = [] for line in block.split('\n'): m = self.CHILD_RE.match(line) if m: # This is a new list item # Check first item for the start index if not items: # Detect the integer value of first list item INTEGER_RE = re.compile(r'(\d+)') self.STARTSWITH = INTEGER_RE.match(m.group(2)).group() # Append to the list items.append(m.group(1) + m.group(4)) elif self.INDENT_RE.match(line): # This is an indented (possibly nested) item. if items[-1].startswith(' ' * MIN_NESTED_LIST_INDENT): # Previous item was indented. Append to that item. items[-1] = '{}\n{}'.format(items[-1], line) else: items.append(line) else: # This is another line of previous item. Append to that item. items[-1] = '{}\n{}'.format(items[-1], line) return items class SaneUListProcessor(SaneOListProcessor): """ Override `SIBLING_TAGS` to not include `ol`. """ TAG: str = 'ul' SIBLING_TAGS = ['ul'] """ Exclude `ol` from list of siblings. """ def __init__(self, parser: blockparser.BlockParser): super().__init__(parser) # Detect an item (e.g., `- item` or `+ item` or `* item`). max_list_start_indent = self.tab_length self.RE = re.compile(r'^[ ]{0,%d}[*+-][ ]+(.*)' % max_list_start_indent) self.CHILD_RE = re.compile(r'^[ ]{0,%d}(([*+-]))[ ]+(.*)' % (MIN_NESTED_LIST_INDENT - 1)) def get_items(self, block: str) -> list[str]: """ Break a block into list items. """ # If first level of list is indented, remove that indentation if (indent_len := len(block) - len(block.lstrip())) > 0: block = self.looseDetab(block, indent_len) items = [] for line in block.split('\n'): m = self.CHILD_RE.match(line) if m: # Append to the list items.append(m.group(3)) elif self.INDENT_RE.match(line): # This is an indented (possibly nested) item. if items[-1].startswith(' ' * MIN_NESTED_LIST_INDENT): # Previous item was indented. Append to that item. 
items[-1] = '{}\n{}'.format(items[-1], line) else: items.append(line) else: # This is another line of previous item. Append to that item. items[-1] = '{}\n{}'.format(items[-1], line) return items class SaneParagraphProcessor(ParagraphProcessor): """ Process Paragraph blocks. """ def __init__(self, parser: BlockParser): super().__init__(parser) max_list_start_indent = self.tab_length self.LIST_RE = re.compile(r"\s{2}\n(\s{0,%d}[\d+*-])" % max_list_start_indent) def run(self, parent: etree.Element, blocks: list[str]) -> None: block = blocks.pop(0) if block.strip(): # Not a blank block. Add to parent, otherwise throw it away. if self.parser.state.isstate('list'): # The parent is a tight-list. # # Check for any children. This will likely only happen in a # tight-list when a header isn't followed by a blank line. # For example: # # * # Header # Line 2 of list item - not part of header. sibling = self.lastChild(parent) if sibling is not None: # Insert after sibling. if sibling.tail: sibling.tail = '{}\n{}'.format(sibling.tail, block) else: sibling.tail = '\n%s' % block else: # Append to parent.text if parent.text: parent.text = '{}\n{}'.format(parent.text, block) else: parent.text = block.lstrip() else: # Check if paragraph contains a list next_list_block = None if list_match := self.LIST_RE.search(block): list_start = list_match.end() - len(list_match.group(1)) next_list_block = block[list_start:] block = block[:list_start] # Create a regular paragraph p = etree.SubElement(parent, 'p') p.text = block.lstrip() # If a list was found, parse its block separately with the paragraph as the parent if next_list_block: self.parser.parseBlocks(p, [next_list_block]) class SaneListExtension(Extension): """ Add sane lists to Markdown. """ def extendMarkdown(self, md): """ Override existing Processors. 
""" md.parser.blockprocessors.register(SaneListIndentProcessor(md.parser), 'indent', 90) md.parser.blockprocessors.register(SaneOListProcessor(md.parser), 'olist', 40) md.parser.blockprocessors.register(SaneUListProcessor(md.parser), 'ulist', 30) md.parser.blockprocessors.register(SaneParagraphProcessor(md.parser), 'paragraph', 10) # Disable uncommon indented codeblocks (as opposed to fenced codeblocks delimited by "```") md.parser.blockprocessors.deregister('code') def makeExtension(**kwargs): # pragma: no cover return SaneListExtension(**kwargs) ================================================ FILE: modules/shared.py ================================================ import argparse import copy import os import shlex import sys from collections import OrderedDict from pathlib import Path import yaml from modules.logging_colors import logger from modules.paths import resolve_user_data_dir from modules.presets import default_preset, default_preset_values # Resolve user_data directory early (before argparse defaults are set) user_data_dir = resolve_user_data_dir() # Text model variables model = None tokenizer = None model_name = 'None' is_seq2seq = False is_multimodal = False model_dirty_from_training = False lora_names = [] bos_token = '' eos_token = '' # Image model variables image_model = None image_model_name = 'None' image_pipeline_type = None # Generation variables stop_everything = False generation_lock = None processing_message = '' # UI variables gradio = {} persistent_interface_state = {} need_restart = False # Parser copied from https://github.com/vladmandic/automatic parser = argparse.ArgumentParser(description="Text Generation Web UI", conflict_handler='resolve', add_help=True, formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=55, indent_increment=2, width=200)) # Basic settings group = parser.add_argument_group('Basic settings') group.add_argument('--user-data-dir', type=str, default=str(user_data_dir), help='Path to the user 
data directory. Default: auto-detected.') group.add_argument('--multi-user', action='store_true', help='Multi-user mode. Chat histories are not saved or automatically loaded. Best suited for small trusted teams.') group.add_argument('--model', type=str, help='Name of the model to load by default.') group.add_argument('--lora', type=str, nargs='+', help='The list of LoRAs to load. If you want to load more than one LoRA, write the names separated by spaces.') group.add_argument('--model-dir', type=str, default=str(user_data_dir / 'models'), help='Path to directory with all the models.') group.add_argument('--lora-dir', type=str, default=str(user_data_dir / 'loras'), help='Path to directory with all the loras.') group.add_argument('--model-menu', action='store_true', help='Show a model menu in the terminal when the web UI is first launched.') group.add_argument('--settings', type=str, help='Load the default interface settings from this yaml file. See user_data/settings-template.yaml for an example. If you create a file called user_data/settings.yaml, this file will be loaded by default without the need to use the --settings flag.') group.add_argument('--extensions', type=str, nargs='+', help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.') group.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') group.add_argument('--idle-timeout', type=int, default=0, help='Unload model after this many minutes of inactivity. 
It will be automatically reloaded when you try to use it again.') # Image generation group = parser.add_argument_group('Image model') group.add_argument('--image-model', type=str, help='Name of the image model to select on startup (overrides saved setting).') group.add_argument('--image-model-dir', type=str, default=str(user_data_dir / 'image_models'), help='Path to directory with all the image models.') group.add_argument('--image-dtype', type=str, default=None, choices=['bfloat16', 'float16'], help='Data type for image model.') group.add_argument('--image-attn-backend', type=str, default=None, choices=['flash_attention_2', 'sdpa'], help='Attention backend for image model.') group.add_argument('--image-cpu-offload', action='store_true', help='Enable CPU offloading for image model.') group.add_argument('--image-compile', action='store_true', help='Compile the image model for faster inference.') group.add_argument('--image-quant', type=str, default=None, choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'], help='Quantization method for image model.') # Model loader group = parser.add_argument_group('Model loader') group.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: Transformers, llama.cpp, ExLlamav3_HF, ExLlamav3, TensorRT-LLM.') # Cache group = parser.add_argument_group('Context and cache') group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=0, metavar='N', help='Context size in tokens. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders.') group.add_argument('--cache-type', '--cache_type', type=str, default='fp16', metavar='N', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. 
q4_q8).') # Speculative decoding group = parser.add_argument_group('Speculative decoding') group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.') group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to draft for speculative decoding.') group.add_argument('--gpu-layers-draft', type=int, default=256, help='Number of layers to offload to the GPU for the draft model.') group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1') group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.') group.add_argument('--spec-type', type=str, default='none', choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], help='Draftless speculative decoding type. Recommended: ngram-mod.') group.add_argument('--spec-ngram-size-n', type=int, default=24, help='N-gram lookup size for ngram speculative decoding.') group.add_argument('--spec-ngram-size-m', type=int, default=48, help='Draft n-gram size for ngram speculative decoding.') group.add_argument('--spec-ngram-min-hits', type=int, default=1, help='Minimum n-gram hits for ngram-map speculative decoding.') # llama.cpp group = parser.add_argument_group('llama.cpp') group.add_argument('--gpu-layers', '--n-gpu-layers', type=int, default=-1, metavar='N', help='Number of layers to offload to the GPU. 
-1 = auto.') group.add_argument('--cpu-moe', action='store_true', help='Move the experts to the CPU (for MoE models).') group.add_argument('--mmproj', type=str, default=None, help='Path to the mmproj file for vision models.') group.add_argument('--streaming-llm', action='store_true', help='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') group.add_argument('--tensor-split', type=str, default=None, help='Split the model across multiple GPUs. Comma-separated list of proportions. Example: 60,40.') group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.') group.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.') group.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.') group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') group.add_argument('--batch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the application level batch size.') group.add_argument('--ubatch-size', type=int, default=1024, help='Maximum number of prompt tokens to batch together when calling llama-server. This is the max physical batch size for computation (device level).') group.add_argument('--threads', type=int, default=0, help='Number of threads to use.') group.add_argument('--threads-batch', type=int, default=0, help='Number of threads to use for batches/prompt processing.') group.add_argument('--numa', action='store_true', help='Activate NUMA task allocation for llama.cpp.') group.add_argument('--parallel', type=int, default=1, help='Number of parallel request slots. The context size is divided equally among slots. 
For example, to have 4 slots with 8192 context each, set ctx_size to 32768.') group.add_argument('--fit-target', type=str, default='512', help='Target VRAM margin per device for auto GPU layers, comma-separated list of values in MiB. A single value is broadcast across all devices.') group.add_argument('--extra-flags', type=str, default=None, help='Extra flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". Example: "override-tensor=exps=CPU"') # Transformers/Accelerate group = parser.add_argument_group('Transformers/Accelerate') group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.') group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.') group.add_argument('--disk-cache-dir', type=str, default=str(user_data_dir / 'cache'), help='Directory to save the disk cache to.') group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).') group.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.') group.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces VRAM usage slightly, but it comes at a performance cost.') group.add_argument('--trust-remote-code', action='store_true', help='Set trust_remote_code=True while loading the model. Necessary for some models.') group.add_argument('--force-safetensors', action='store_true', help='Set use_safetensors=True while loading the model. This prevents arbitrary code execution.') group.add_argument('--no_use_fast', action='store_true', help='Set use_fast=False while loading the tokenizer (it\'s True by default). 
Use this if you have any problems related to use_fast.') group.add_argument('--attn-implementation', type=str, default='sdpa', metavar="IMPLEMENTATION", help='Attention implementation. Valid options: sdpa, eager, flash_attention_2.') # bitsandbytes 4-bit group = parser.add_argument_group('bitsandbytes 4-bit') group.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision (using bitsandbytes).') group.add_argument('--use_double_quant', action='store_true', help='use_double_quant for 4-bit.') group.add_argument('--compute_dtype', type=str, default='float16', help='compute dtype for 4-bit. Valid options: bfloat16, float16, float32.') group.add_argument('--quant_type', type=str, default='nf4', help='quant_type for 4-bit. Valid options: nf4, fp4.') # ExLlamaV3 group = parser.add_argument_group('ExLlamaV3') group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.') group.add_argument('--enable-tp', '--enable_tp', action='store_true', help='Enable Tensor Parallelism (TP) to split the model across GPUs.') group.add_argument('--tp-backend', type=str, default='native', help='The backend for tensor parallelism. Valid options: native, nccl. Default: native.') group.add_argument('--cfg-cache', action='store_true', help='Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader.') # Gradio group = parser.add_argument_group('Gradio') group.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.') group.add_argument('--listen-port', type=int, help='The listening port that the server will use.') group.add_argument('--listen-host', type=str, help='The hostname that the server will use.') group.add_argument('--share', action='store_true', help='Create a public URL. 
This is useful for running the web UI on Google Colab or similar.') group.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') group.add_argument('--gradio-auth', type=str, help='Set Gradio authentication password in the format "username:password". Multiple credentials can also be supplied with "u1:p1,u2:p2,u3:p3".', default=None) group.add_argument('--gradio-auth-path', type=str, help='Set the Gradio authentication file path. The file should contain one or more user:password pairs in the same format as above.', default=None) group.add_argument('--ssl-keyfile', type=str, help='The path to the SSL certificate key file.', default=None) group.add_argument('--ssl-certfile', type=str, help='The path to the SSL certificate cert file.', default=None) group.add_argument('--subpath', type=str, help='Customize the subpath for gradio, use with reverse proxy') group.add_argument('--old-colors', action='store_true', help='Use the legacy Gradio colors, before the December/2024 update.') group.add_argument('--portable', action='store_true', help='Hide features not available in portable mode like training.') # API group = parser.add_argument_group('API') group.add_argument('--api', action='store_true', help='Enable the API extension.') group.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.') group.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) group.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') group.add_argument('--api-key', type=str, default='', help='API authentication key.') group.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. 
If not set, will be the same as --api-key.') group.add_argument('--api-enable-ipv6', action='store_true', help='Enable IPv6 for the API') group.add_argument('--api-disable-ipv4', action='store_true', help='Disable IPv4 for the API') group.add_argument('--nowebui', action='store_true', help='Do not launch the Gradio UI. Useful for launching the API in standalone mode.') # API generation defaults _d = default_preset_values group = parser.add_argument_group('API generation defaults') group.add_argument('--temperature', type=float, default=_d['temperature'], metavar='N', help='Temperature') group.add_argument('--dynatemp-low', type=float, default=_d['dynatemp_low'], metavar='N', help='Dynamic temperature low') group.add_argument('--dynatemp-high', type=float, default=_d['dynatemp_high'], metavar='N', help='Dynamic temperature high') group.add_argument('--dynatemp-exponent', type=float, default=_d['dynatemp_exponent'], metavar='N', help='Dynamic temperature exponent') group.add_argument('--smoothing-factor', type=float, default=_d['smoothing_factor'], metavar='N', help='Smoothing factor') group.add_argument('--smoothing-curve', type=float, default=_d['smoothing_curve'], metavar='N', help='Smoothing curve') group.add_argument('--top-p', type=float, default=_d['top_p'], metavar='N', help='Top P') group.add_argument('--top-k', type=int, default=_d['top_k'], metavar='N', help='Top K') group.add_argument('--min-p', type=float, default=_d['min_p'], metavar='N', help='Min P') group.add_argument('--top-n-sigma', type=float, default=_d['top_n_sigma'], metavar='N', help='Top N Sigma') group.add_argument('--typical-p', type=float, default=_d['typical_p'], metavar='N', help='Typical P') group.add_argument('--xtc-threshold', type=float, default=_d['xtc_threshold'], metavar='N', help='XTC threshold') group.add_argument('--xtc-probability', type=float, default=_d['xtc_probability'], metavar='N', help='XTC probability') group.add_argument('--epsilon-cutoff', type=float, 
default=_d['epsilon_cutoff'], metavar='N', help='Epsilon cutoff') group.add_argument('--eta-cutoff', type=float, default=_d['eta_cutoff'], metavar='N', help='Eta cutoff') group.add_argument('--tfs', type=float, default=_d['tfs'], metavar='N', help='TFS') group.add_argument('--top-a', type=float, default=_d['top_a'], metavar='N', help='Top A') group.add_argument('--adaptive-target', type=float, default=_d['adaptive_target'], metavar='N', help='Adaptive target') group.add_argument('--adaptive-decay', type=float, default=_d['adaptive_decay'], metavar='N', help='Adaptive decay') group.add_argument('--dry-multiplier', type=float, default=_d['dry_multiplier'], metavar='N', help='DRY multiplier') group.add_argument('--dry-allowed-length', type=int, default=_d['dry_allowed_length'], metavar='N', help='DRY allowed length') group.add_argument('--dry-base', type=float, default=_d['dry_base'], metavar='N', help='DRY base') group.add_argument('--repetition-penalty', type=float, default=_d['repetition_penalty'], metavar='N', help='Repetition penalty') group.add_argument('--frequency-penalty', type=float, default=_d['frequency_penalty'], metavar='N', help='Frequency penalty') group.add_argument('--presence-penalty', type=float, default=_d['presence_penalty'], metavar='N', help='Presence penalty') group.add_argument('--encoder-repetition-penalty', type=float, default=_d['encoder_repetition_penalty'], metavar='N', help='Encoder repetition penalty') group.add_argument('--no-repeat-ngram-size', type=int, default=_d['no_repeat_ngram_size'], metavar='N', help='No repeat ngram size') group.add_argument('--repetition-penalty-range', type=int, default=_d['repetition_penalty_range'], metavar='N', help='Repetition penalty range') group.add_argument('--penalty-alpha', type=float, default=_d['penalty_alpha'], metavar='N', help='Penalty alpha') group.add_argument('--guidance-scale', type=float, default=_d['guidance_scale'], metavar='N', help='Guidance scale') 
group.add_argument('--mirostat-mode', type=int, default=_d['mirostat_mode'], metavar='N', help='Mirostat mode') group.add_argument('--mirostat-tau', type=float, default=_d['mirostat_tau'], metavar='N', help='Mirostat tau') group.add_argument('--mirostat-eta', type=float, default=_d['mirostat_eta'], metavar='N', help='Mirostat eta') group.add_argument('--do-sample', action=argparse.BooleanOptionalAction, default=_d['do_sample'], help='Do sample') group.add_argument('--dynamic-temperature', action=argparse.BooleanOptionalAction, default=_d['dynamic_temperature'], help='Dynamic temperature') group.add_argument('--temperature-last', action=argparse.BooleanOptionalAction, default=_d['temperature_last'], help='Temperature last') group.add_argument('--sampler-priority', type=str, default=_d['sampler_priority'], metavar='N', help='Sampler priority') group.add_argument('--dry-sequence-breakers', type=str, default=_d['dry_sequence_breakers'], metavar='N', help='DRY sequence breakers') group.add_argument('--enable-thinking', action=argparse.BooleanOptionalAction, default=True, help='Enable thinking') group.add_argument('--reasoning-effort', type=str, default='medium', metavar='N', help='Reasoning effort') group.add_argument('--chat-template-file', type=str, default=None, help='Path to a chat template file (.jinja, .jinja2, or .yaml) to use as the default instruction template for API requests. 
Overrides the model\'s built-in template.') # Handle CMD_FLAGS.txt cmd_flags_path = user_data_dir / "CMD_FLAGS.txt" if cmd_flags_path.exists(): with cmd_flags_path.open('r', encoding='utf-8') as f: cmd_flags = ' '.join( line.strip().rstrip('\\').strip() for line in f if line.strip().rstrip('\\').strip() and not line.strip().startswith('#') ) if cmd_flags: # Command-line takes precedence over CMD_FLAGS.txt sys.argv = [sys.argv[0]] + shlex.split(cmd_flags) + sys.argv[1:] args = parser.parse_args() user_data_dir = Path(args.user_data_dir) # Update from parsed args (may differ from pre-parse) original_args = copy.deepcopy(args) args_defaults = parser.parse_args([]) # Create a mapping of all argument aliases to their canonical names alias_to_dest = {} for action in parser._actions: for opt in action.option_strings: alias_to_dest[opt.lstrip('-').replace('-', '_')] = action.dest provided_arguments = [] for arg in sys.argv[1:]: arg = arg.lstrip('-').replace('-', '_') if arg in alias_to_dest: provided_arguments.append(alias_to_dest[arg]) elif hasattr(args, arg): provided_arguments.append(arg) # Default generation parameters neutral_samplers = default_preset() # UI defaults settings = { 'show_controls': True, 'start_with': '', 'mode': 'instruct', 'chat_style': 'cai-chat', 'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>". 
Reply directly, without starting the reply with the character name.\n\n<|prompt|>', 'enable_web_search': False, 'web_search_pages': 3, 'selected_tools': [], 'prompt-notebook': '', 'preset': 'Top-P' if (user_data_dir / 'presets/Top-P.yaml').exists() else None, 'max_new_tokens': 512, 'max_new_tokens_min': 1, 'max_new_tokens_max': 4096, 'prompt_lookup_num_tokens': 0, 'max_tokens_second': 0, 'auto_max_new_tokens': True, 'ban_eos_token': False, 'add_bos_token': True, 'enable_thinking': True, 'reasoning_effort': 'medium', 'skip_special_tokens': True, 'stream': True, 'static_cache': False, 'truncation_length': 8192, 'seed': -1, 'custom_stopping_strings': '', 'custom_token_bans': '', 'negative_prompt': '', 'dark_theme': True, 'show_two_notebook_columns': False, 'paste_to_attachment': False, 'include_past_attachments': True, # Generation parameters - Curve shape 'temperature': neutral_samplers['temperature'], 'dynatemp_low': neutral_samplers['dynatemp_low'], 'dynatemp_high': neutral_samplers['dynatemp_high'], 'dynatemp_exponent': neutral_samplers['dynatemp_exponent'], 'smoothing_factor': neutral_samplers['smoothing_factor'], 'smoothing_curve': neutral_samplers['smoothing_curve'], # Generation parameters - Curve cutoff 'top_p': 0.95, 'top_k': neutral_samplers['top_k'], 'min_p': neutral_samplers['min_p'], 'top_n_sigma': neutral_samplers['top_n_sigma'], 'typical_p': neutral_samplers['typical_p'], 'xtc_threshold': neutral_samplers['xtc_threshold'], 'xtc_probability': neutral_samplers['xtc_probability'], 'epsilon_cutoff': neutral_samplers['epsilon_cutoff'], 'eta_cutoff': neutral_samplers['eta_cutoff'], 'tfs': neutral_samplers['tfs'], 'top_a': neutral_samplers['top_a'], 'adaptive_target': neutral_samplers['adaptive_target'], 'adaptive_decay': neutral_samplers['adaptive_decay'], # Generation parameters - Repetition suppression 'dry_multiplier': neutral_samplers['dry_multiplier'], 'dry_allowed_length': neutral_samplers['dry_allowed_length'], 'dry_base': 
neutral_samplers['dry_base'], 'repetition_penalty': neutral_samplers['repetition_penalty'], 'frequency_penalty': neutral_samplers['frequency_penalty'], 'presence_penalty': neutral_samplers['presence_penalty'], 'encoder_repetition_penalty': neutral_samplers['encoder_repetition_penalty'], 'no_repeat_ngram_size': neutral_samplers['no_repeat_ngram_size'], 'repetition_penalty_range': neutral_samplers['repetition_penalty_range'], # Generation parameters - Alternative sampling methods 'penalty_alpha': neutral_samplers['penalty_alpha'], 'guidance_scale': neutral_samplers['guidance_scale'], 'mirostat_mode': neutral_samplers['mirostat_mode'], 'mirostat_tau': neutral_samplers['mirostat_tau'], 'mirostat_eta': neutral_samplers['mirostat_eta'], # Generation parameters - Other options 'do_sample': neutral_samplers['do_sample'], 'dynamic_temperature': neutral_samplers['dynamic_temperature'], 'temperature_last': neutral_samplers['temperature_last'], 'sampler_priority': neutral_samplers['sampler_priority'], 'dry_sequence_breakers': neutral_samplers['dry_sequence_breakers'], 'grammar_string': '', # Character settings 'character': 'Assistant', 'user': 'Default', 'name1': 'You', 'name2': 'AI', 'user_bio': '', 'context': 'The following is a conversation with an AI Large Language Model. The AI has been trained to answer questions, provide recommendations, and help with decision making. The AI follows user requests. The AI thinks outside the box.', 'greeting': 'How can I help you today?', 'custom_system_message': '', 'instruction_template_str': "{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not ns.found -%}\n {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' 
+ '\\n\\n' -}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {{- '' + message['content'] + '\\n\\n' -}}\n {%- else -%}\n {%- if message['role'] == 'user' -%}\n {{-'### Instruction:\\n' + message['content'] + '\\n\\n'-}}\n {%- else -%}\n {{-'### Response:\\n' + message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{-'### Response:\\n'-}}\n{%- endif -%}", 'chat_template_str': "{%- for message in messages %}\n {%- if message['role'] == 'system' -%}\n {%- if message['content'] -%}\n {{- message['content'] + '\\n\\n' -}}\n {%- endif -%}\n {%- if user_bio -%}\n {{- user_bio + '\\n\\n' -}}\n {%- endif -%}\n {%- elif message['role'] == 'tool' -%}\n {{- '[Tool result: ' + message['content'] + ']\\n' -}}\n {%- elif message['role'] == 'user' -%}\n {{- name1 + ': ' + message['content'] + '\\n'-}}\n {%- elif message['tool_calls'] is defined and message['tool_calls'] -%}\n {%- for tc in message['tool_calls'] -%}\n {{- '[Calling: ' + tc['function']['name'] + '(' + tc['function']['arguments'] + ')]\\n' -}}\n {%- endfor -%}\n {%- else -%}\n {{- name2 + ': ' + message['content'] + '\\n' -}}\n {%- endif -%}\n{%- endfor -%}\n{%- if add_generation_prompt %}\n {{- name2 + ':' -}}\n{%- endif %}", # Extensions 'default_extensions': [], # Image generation settings 'image_prompt': '', 'image_neg_prompt': '', 'image_width': 1024, 'image_height': 1024, 'image_aspect_ratio': '1:1 Square', 'image_steps': 9, 'image_cfg_scale': 0.0, 'image_seed': -1, 'image_batch_size': 1, 'image_batch_count': 1, 'image_llm_variations': False, 'image_llm_variations_prompt': 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. 
Do not add any explanations, prefixes, or additional text.', 'image_model_menu': 'None', 'image_dtype': 'bfloat16', 'image_attn_backend': 'flash_attention_2', 'image_cpu_offload': False, 'image_compile': False, 'image_quant': 'none', } default_settings = copy.deepcopy(settings) def do_cmd_flags_warnings(): # Validate --chat-template-file if args.chat_template_file and not Path(args.chat_template_file).is_file(): logger.error(f"--chat-template-file: file not found: {args.chat_template_file}") sys.exit(1) # Security warnings if args.trust_remote_code: logger.warning( "The `--trust-remote-code` flag is enabled.\n" "This allows models to execute arbitrary code on your machine.\n\n" "1. Only use with models from sources you fully trust.\n" "2. Set an access password with `--gradio-auth`." ) if 'COLAB_GPU' not in os.environ and not args.nowebui: if args.share: logger.warning("The gradio \"share link\" feature uses a proprietary executable to create a reverse tunnel. Use it with care.") if any((args.listen, args.share)) and not any((args.gradio_auth, args.gradio_auth_path)): logger.warning("You are potentially exposing the web UI to the entire internet without any access password.\nYou can create one with the \"--gradio-auth\" flag like this:\n\n--gradio-auth username:password\n\nMake sure to replace username:password with your own.") if args.multi_user: logger.warning( 'Multi-user mode is enabled. Known limitations:' '\n- The Stop button stops generation for all users, not just you.' '\n- Chat history is not saved and will be lost on page refresh.' '\n- Only one user can generate at a time unless using a parallel-capable backend (e.g. llama.cpp with --parallel N for N > 1, or ExLlamaV3).' '\n\nThis mode works best for small trusted teams.' '\n\nDo not expose publicly. 
Grayed-out actions can easily be bypassed client-side.\n' ) def apply_image_model_cli_overrides(): """Apply command-line overrides for image model settings.""" if args.image_model is not None: settings['image_model_menu'] = args.image_model if args.image_dtype is not None: settings['image_dtype'] = args.image_dtype if args.image_attn_backend is not None: settings['image_attn_backend'] = args.image_attn_backend if args.image_cpu_offload: settings['image_cpu_offload'] = True if args.image_compile: settings['image_compile'] = True if args.image_quant is not None: settings['image_quant'] = args.image_quant def fix_loader_name(name): if not name: return name name = name.lower() if name in ['llama.cpp', 'llamacpp', 'llama-cpp', 'llama cpp']: return 'llama.cpp' elif name in ['transformers', 'huggingface', 'hf', 'hugging_face', 'hugging face']: return 'Transformers' elif name in ['exllamav3-hf', 'exllamav3_hf', 'exllama-v3-hf', 'exllama_v3_hf', 'exllama-v3_hf', 'exllama3-hf', 'exllama3_hf', 'exllama-3-hf', 'exllama_3_hf', 'exllama-3_hf']: return 'ExLlamav3_HF' elif name in ['exllamav3']: return 'ExLlamav3' elif name in ['tensorrt', 'tensorrtllm', 'tensorrt_llm', 'tensorrt-llm', 'tensort', 'tensortllm']: return 'TensorRT-LLM' def add_extension(name, last=False): if args.extensions is None: args.extensions = [name] elif last: args.extensions = [x for x in args.extensions if x != name] args.extensions.append(name) elif name not in args.extensions: args.extensions.append(name) def is_chat(): return True def load_user_config(): ''' Loads custom model-specific settings ''' if Path(f'{args.model_dir}/config-user.yaml').exists(): file_content = open(f'{args.model_dir}/config-user.yaml', 'r').read().strip() if file_content: user_config = yaml.safe_load(file_content) else: user_config = {} else: user_config = {} return user_config args.loader = fix_loader_name(args.loader) # Activate the API extension if args.api or args.public_api: add_extension('openai', last=True) # Load 
model-specific settings p = Path(f'{args.model_dir}/config.yaml') if p.exists(): model_config = yaml.safe_load(open(p, 'r').read()) else: model_config = {} del p # Load custom model-specific settings user_config = load_user_config() model_config = OrderedDict(model_config) user_config = OrderedDict(user_config) ================================================ FILE: modules/tensorrt_llm.py ================================================ from pathlib import Path from tensorrt_llm._tensorrt_engine import LLM from tensorrt_llm.llmapi import SamplingParams from modules import shared from modules.logging_colors import logger class TensorRTLLMModel: def __init__(self): pass @classmethod def from_pretrained(cls, path_to_model): path_to_model = Path(f'{shared.args.model_dir}') / Path(path_to_model) llm = LLM( model=str(path_to_model), skip_tokenizer_init=False, ) result = cls() result.llm = llm result.tokenizer = llm.tokenizer return result def generate_with_streaming(self, prompt, state): sampling_params = SamplingParams( max_tokens=state['max_new_tokens'] if not state['auto_max_new_tokens'] else state['truncation_length'] - len(shared.tokenizer.encode(prompt)), end_id=shared.tokenizer.eos_token_id, temperature=state['temperature'], top_k=state['top_k'], top_p=state['top_p'], min_p=state['min_p'], repetition_penalty=state['repetition_penalty'], presence_penalty=state['presence_penalty'], frequency_penalty=state['frequency_penalty'], no_repeat_ngram_size=state['no_repeat_ngram_size'] if state['no_repeat_ngram_size'] > 0 else None, seed=state['seed'], ignore_eos=state['ban_eos_token'], add_special_tokens=state['add_bos_token'], skip_special_tokens=state['skip_special_tokens'], ) stop_event = state.get('stop_event') result = self.llm.generate_async(prompt, sampling_params=sampling_params, streaming=True) cumulative_reply = '' for output in result: if shared.stop_everything or (stop_event and stop_event.is_set()): result.abort() break text_diff = output.outputs[0].text_diff 
if text_diff: cumulative_reply += text_diff yield cumulative_reply def generate(self, prompt, state): output = '' for output in self.generate_with_streaming(prompt, state): pass return output def unload(self): if hasattr(self, 'llm') and self.llm is not None: self.llm.shutdown() self.llm = None ================================================ FILE: modules/text_generation.py ================================================ import ast import copy import html import pprint import random import time import traceback import numpy as np import modules.shared as shared from modules import models from modules.callbacks import Iteratorize from modules.extensions import apply_extensions from modules.html_generator import generate_basic_html from modules.logging_colors import logger from modules.utils import check_model_loaded def generate_reply(*args, **kwargs): if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']: from modules.models import load_model shared.model, shared.tokenizer = load_model(shared.model_name) state = args[1] if len(args) > 1 else kwargs.get('state', {}) use_parallel = ( state.get('stop_event') is not None and shared.model.__class__.__name__ in ['Exllamav3Model', 'LlamaServer', 'TensorRTLLMModel'] and (shared.model.__class__.__name__ != 'LlamaServer' or shared.args.parallel > 1) ) if not use_parallel: shared.generation_lock.acquire() try: for result in _generate_reply(*args, **kwargs): yield result finally: models.last_generation_time = time.time() if not use_parallel: shared.generation_lock.release() def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False): # Find the appropriate generation function generate_func = apply_extensions('custom_generate_reply') if generate_func is None: model_is_loaded, error_message = check_model_loaded() if not model_is_loaded: yield '' return if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav3Model', 
'TensorRTLLMModel']: generate_func = generate_reply_custom else: generate_func = generate_reply_HF if generate_func != generate_reply_HF and shared.args.verbose: logger.info("PROMPT=") print_prompt(question) # Prepare the input original_question = question if not is_chat: state = apply_extensions('state', state) question = apply_extensions('input', question, state) # Find the stopping strings all_stop_strings = [] for st in (stopping_strings, state['custom_stopping_strings']): if type(st) is str: st = ast.literal_eval(f"[{st}]") if type(st) is list and len(st) > 0: all_stop_strings += st shared.stop_everything = False reply = '' is_stream = state['stream'] if len(all_stop_strings) > 0 and not state['stream']: original_logits_processor = state.get('logits_processor') stop_event_ref = state.pop('stop_event', None) state = copy.deepcopy(state) if stop_event_ref is not None: state['stop_event'] = stop_event_ref if original_logits_processor is not None: state['logits_processor'] = original_logits_processor state['stream'] = True # Generate last_update = -1 latency_threshold = 1 / 1000 for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat): cur_time = time.monotonic() reply, stop_found = apply_stopping_strings(reply, all_stop_strings) if escape_html: reply = html.escape(reply) if is_stream: # Limit number of tokens/second to make text readable in real time if state['max_tokens_second'] > 0: diff = 1 / state['max_tokens_second'] - (cur_time - last_update) if diff > 0: time.sleep(diff) last_update = time.monotonic() yield reply # Limit updates to avoid lag in the Gradio UI # API updates are not limited else: # If 'generate_func' takes less than 0.001 seconds to yield the next token # (equivalent to more than 1000 tok/s), assume that the UI is lagging behind and skip yielding if (cur_time - last_update) > latency_threshold: yield reply last_update = time.monotonic() stop_event = state.get('stop_event') if stop_found or 
shared.stop_everything or (stop_event and stop_event.is_set()): break if not is_chat: reply = apply_extensions('output', reply, state) yield reply def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None): if shared.tokenizer is None: raise ValueError('No tokenizer is loaded') # llama.cpp case if shared.model.__class__.__name__ == 'LlamaServer': input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token) input_ids = np.array(input_ids).reshape(1, len(input_ids)) if truncation_length is not None: input_ids = input_ids[:, -truncation_length:] return input_ids # All other model types else: import torch from modules.torch_utils import get_device if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel']: input_ids = shared.tokenizer.encode(str(prompt)) if shared.model.__class__.__name__ not in ['Exllamav3Model']: input_ids = np.array(input_ids).reshape(1, len(input_ids)) else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens) if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None: if add_bos_token: # Add BOS token if missing if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0: bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]]) input_ids = torch.cat((bos_tensor, input_ids), 1) # Always prevent double BOS tokens (regardless of add_bos_token setting) while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id: input_ids = input_ids[:, 1:] if truncation_length is not None: input_ids = input_ids[:, -truncation_length:] if shared.model.__class__.__name__ in ['Exllamav3Model', 'TensorRTLLMModel'] or shared.args.cpu: return input_ids else: device = get_device() if device: return input_ids.to(device) return input_ids def decode(output_ids, skip_special_tokens=True): if 
shared.tokenizer is None: raise ValueError('No tokenizer is loaded') return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens) def get_encoded_length(prompt): length_after_extensions = apply_extensions('tokenized_length', prompt) if length_after_extensions is not None: return length_after_extensions return len(encode(prompt)[0]) def get_token_ids(prompt): tokens = encode(prompt)[0] decoded_tokens = [shared.tokenizer.decode([int(i)]) for i in tokens] output = '' for row in list(zip(tokens, decoded_tokens)): output += f"{str(int(row[0])).ljust(5)} - {repr(row[1])}\n" return output def get_max_prompt_length(state): return state['truncation_length'] - state['max_new_tokens'] def generate_reply_wrapper(question, state, stopping_strings=None): """ Returns formatted outputs for the UI """ reply = question if not shared.is_seq2seq else '' yield formatted_outputs(reply, shared.model_name) for reply in generate_reply(question, state, stopping_strings, is_chat=False, escape_html=True, for_ui=True): if not shared.is_seq2seq: reply = question + reply yield formatted_outputs(reply, shared.model_name) def formatted_outputs(reply, model_name): return html.unescape(reply), generate_basic_html(reply) def set_manual_seed(seed): seed = int(seed) if seed == -1: seed = random.randint(1, 2**31) if shared.args.loader != 'llama.cpp': import torch from transformers import is_torch_npu_available, is_torch_xpu_available torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) elif is_torch_xpu_available(): torch.xpu.manual_seed_all(seed) elif is_torch_npu_available(): torch.npu.manual_seed_all(seed) return seed def stop_everything_event(): shared.stop_everything = True def apply_stopping_strings(reply, all_stop_strings): stop_found = False for string in all_stop_strings: idx = reply.find(string) if idx != -1: reply = reply[:idx] stop_found = True break if not stop_found: # If something like "\nYo" is generated just before "\nYou:" # is 
completed, trim it for string in all_stop_strings: for j in range(len(string) - 1, 0, -1): if reply[-j:] == string[:j]: reply = reply[:-j] break else: continue break return reply, stop_found def get_reply_from_output_ids(output_ids, state=None, starting_from=0): import torch if torch.cuda.is_available(): torch.cuda.synchronize() reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True) # Handle tokenizers that do not add the leading space for the first token if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '): first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])) if isinstance(first_token, (bytes,)): # try to decode the bytes to a string # if it fails, which means it's not a string in this turn, just ignore it try: first_token = first_token.decode('utf8') except UnicodeDecodeError: first_token = '' if first_token.startswith('▁'): reply = ' ' + reply return reply def generate_reply_HF(question, original_question, state, stopping_strings=None, is_chat=False): import torch import transformers from transformers import LogitsProcessorList from modules.grammar.grammar_utils import initialize_grammar from modules.grammar.logits_process import ( GrammarConstrainedLogitsProcessor ) from modules.torch_utils import clear_torch_cache, get_device from modules.transformers_loader import ( Stream, _StopEverythingStoppingCriteria ) if shared.args.loader == 'Transformers': clear_torch_cache() seed = set_manual_seed(state['seed']) generate_params = {} for k in [ 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', 'tfs', 'top_a', 'top_n_sigma', 'adaptive_target', 'adaptive_decay', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', 
'encoder_repetition_penalty', 'no_repeat_ngram_size', 'repetition_penalty_range', 'penalty_alpha', 'guidance_scale', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'max_new_tokens', 'do_sample', 'dynamic_temperature', 'temperature_last', 'dry_sequence_breakers', ]: if k in state: generate_params[k] = state[k] for k in ['epsilon_cutoff', 'eta_cutoff']: if state[k] > 0: generate_params[k] = state[k] * 1e-4 if state['prompt_lookup_num_tokens'] > 0: generate_params['prompt_lookup_num_tokens'] = state['prompt_lookup_num_tokens'] if state['ban_eos_token']: generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] if state['static_cache']: generate_params['cache_implementation'] = 'static' if isinstance(state['sampler_priority'], list) and len(state['sampler_priority']) > 0: generate_params['sampler_priority'] = state['sampler_priority'] elif isinstance(state['sampler_priority'], str) and state['sampler_priority'].strip() != '': generate_params['sampler_priority'] = [x.strip() for x in state['sampler_priority'].replace('\n', ',').split(',') if x.strip()] if state['custom_token_bans']: to_ban = [int(x.strip()) for x in state['custom_token_bans'].split(',') if x.strip()] if len(to_ban) > 0: if generate_params.get('suppress_tokens', None): generate_params['suppress_tokens'] += to_ban else: generate_params['suppress_tokens'] = to_ban if state['negative_prompt'] != '': generate_params['negative_prompt_ids'] = encode(state['negative_prompt']) generate_params.update({'use_cache': not shared.args.no_cache}) # Encode the input input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) output = input_ids[0] if state['auto_max_new_tokens']: generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1] # Add the encoded tokens to generate_params question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) original_input_ids = input_ids 
generate_params.update({'inputs': input_ids}) if inputs_embeds is not None: generate_params.update({'inputs_embeds': inputs_embeds}) # Stopping criteria / eos token eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] generate_params['eos_token_id'] = eos_token_ids generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()) # Logits processor processor = state.get('logits_processor', LogitsProcessorList([])) if not isinstance(processor, LogitsProcessorList): processor = LogitsProcessorList([processor]) # Grammar if state['grammar_string'].strip() != '': grammar = initialize_grammar(state['grammar_string']) grammar_processor = GrammarConstrainedLogitsProcessor(grammar) processor.append(grammar_processor) apply_extensions('logits_processor', processor, input_ids) generate_params['logits_processor'] = processor if shared.args.verbose: logger.info("GENERATE_PARAMS=") filtered_params = {key: value for key, value in generate_params.items() if not isinstance(value, torch.Tensor)} pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(filtered_params) print() logger.info("PROMPT=") print_prompt(decode(input_ids[0], skip_special_tokens=False)) t0 = time.time() try: if not is_chat and not shared.is_seq2seq: yield '' # Generate the entire reply at once. if not state['stream']: with torch.no_grad(): output = shared.model.generate(**generate_params)[0] device = get_device() if device: output = output.to(device) starting_from = 0 if shared.is_seq2seq else len(input_ids[0]) yield get_reply_from_output_ids(output, state, starting_from=starting_from) # Stream the reply 1 token at a time. # This is based on the trick of using 'stopping_criteria' to create an iterator. 
else: def generate_with_callback(callback=None, *args, **kwargs): kwargs['stopping_criteria'].append(Stream(callback_func=callback)) with torch.no_grad(): shared.model.generate(**kwargs) def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, [], kwargs, callback=None) with generate_with_streaming(**generate_params) as generator: cumulative_reply = '' starting_from = 0 if shared.is_seq2seq else len(input_ids[0]) for output in generator: if output[-1] in eos_token_ids: break new_content = get_reply_from_output_ids(output, state, starting_from=starting_from) # check the partial unicode character if chr(0xfffd) in new_content: continue cumulative_reply += new_content starting_from = len(output) yield cumulative_reply except Exception: traceback.print_exc() finally: t1 = time.time() original_tokens = len(original_input_ids[0]) new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0) logger.info(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return def generate_reply_custom(question, original_question, state, stopping_strings=None, is_chat=False): """ For models that do not use the transformers library for sampling """ stop_event_ref = state.pop('stop_event', None) state = copy.deepcopy(state) if stop_event_ref is not None: state['stop_event'] = stop_event_ref state['seed'] = set_manual_seed(state['seed']) t0 = time.time() reply = '' try: if not is_chat: yield '' if not state['stream']: reply = shared.model.generate(question, state) yield reply else: for reply in shared.model.generate_with_streaming(question, state): yield reply except Exception: traceback.print_exc() finally: t1 = time.time() if hasattr(shared.model, 'last_prompt_token_count'): original_tokens = shared.model.last_prompt_token_count new_tokens = len(encode(reply)[0]) if reply else 0 else: original_tokens = len(encode(original_question)[0]) new_tokens = 
len(encode(original_question + reply)[0]) - original_tokens logger.info(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {state["seed"]})') return def print_prompt(prompt, max_chars=-1): DARK_YELLOW = "\033[38;5;3m" RESET = "\033[0m" if max_chars > 0 and len(prompt) > max_chars: half_chars = max_chars // 2 hidden_len = len(prompt[half_chars:-half_chars]) hidden_msg = f"{DARK_YELLOW}[...{hidden_len} characters hidden...]{RESET}" print(prompt[:half_chars] + hidden_msg + prompt[-half_chars:]) else: print(prompt) print() ================================================ FILE: modules/tool_parsing.py ================================================ import json import random import re def get_tool_call_id() -> str: letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789" b = [random.choice(letter_bytes) for _ in range(8)] return "call_" + "".join(b).lower() # All known opening markers for tool calls across model formats. TOOL_CALL_OPENING_MARKERS = [ '', '', '', '<|tool_call_begin|>', '<|tool_calls_section_begin|>', '<|tool▁call▁begin|>', '<|tool▁calls▁begin|>', '[TOOL_CALLS]', 'to=functions.', '<|channel|>commentary', ] def streaming_tool_buffer_check(text, markers=None, tool_names=None, check_bare_names=False): ''' Check whether streaming output should be withheld because it may contain tool-call markup. Args: text: Full accumulated internal text. markers: Template-specific markers for partial-prefix matching. If None, falls back to TOOL_CALL_OPENING_MARKERS. tool_names: List of tool function names. check_bare_names: Whether to do partial-prefix matching on tool names (for models with unknown template format). ''' # Full marker found in text → buffer permanently. # Always checks ALL known markers regardless of template (cheap safety net). 
for marker in TOOL_CALL_OPENING_MARKERS: if marker in text: return True # Bare function-name full match: "get_weather{...}" or "get_weather {...}" if tool_names: for name in tool_names: if name + '{' in text or name + ' {' in text: return True # Partial-prefix matching: only for template-specific markers. for marker in (markers if markers is not None else TOOL_CALL_OPENING_MARKERS): for prefix_len in range(min(len(marker) - 1, len(text)), 0, -1): if text.endswith(marker[:prefix_len]): return True # Bare-name partial matching: only when template format is unknown. if check_bare_names and tool_names: for name in tool_names: if text.endswith(name): return True for prefix_len in range(min(len(name) - 1, len(text)), 0, -1): if text.endswith(name[:prefix_len]): return True return False def check_and_sanitize_tool_call_candidate(candidate_dict: dict, tool_names: list[str]): # check if property 'function' exists and is a dictionary, otherwise adapt dict if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str): candidate_dict = {"type": "function", "function": candidate_dict} if 'function' in candidate_dict and isinstance(candidate_dict['function'], str): candidate_dict['name'] = candidate_dict['function'] del candidate_dict['function'] candidate_dict = {"type": "function", "function": candidate_dict} if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict): # check if 'name' exists within 'function' and is part of known tools if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names: candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value # map property 'parameters' used by some older models to 'arguments' if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]: candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"] del 
candidate_dict["function"]["parameters"] return candidate_dict return None def _extract_balanced_json(text: str, start: int) -> str | None: """Extract a balanced JSON object from text starting at the given position. Walks through the string tracking brace depth and string boundaries to correctly handle arbitrary nesting levels. """ if start >= len(text) or text[start] != '{': return None depth = 0 in_string = False escape_next = False for i in range(start, len(text)): c = text[i] if escape_next: escape_next = False continue if c == '\\' and in_string: escape_next = True continue if c == '"': in_string = not in_string continue if in_string: continue if c == '{': depth += 1 elif c == '}': depth -= 1 if depth == 0: return text[start:i + 1] return None def _parse_channel_tool_calls(answer: str, tool_names: list[str]): """Parse channel-based tool calls used by GPT-OSS and similar models. Format: <|start|>assistant to=functions.func_name<|channel|>commentary json<|message|>{"arg": "value"} or: <|channel|>commentary to=functions.func_name <|constrain|>json<|message|>{"arg": "value"} """ matches = [] start_pos = None # Pattern 1: to=functions.NAME before <|channel|> (GPT-OSS primary format) # Pattern 2: to=functions.NAME after <|channel|> (alternative format) patterns = [ r'to=functions\.([^<\s]+)\s*<\|channel\|>[^<]*<\|message\|>', r'<\|channel\|>\w+ to=functions\.([^<\s]+).*?<\|message\|>', ] for pattern in patterns: for m in re.finditer(pattern, answer): func_name = m.group(1).strip() if func_name not in tool_names: continue json_str = _extract_balanced_json(answer, m.end()) if json_str is None: continue try: arguments = json.loads(json_str) if start_pos is None: prefix = answer.rfind('<|start|>assistant', 0, m.start()) start_pos = prefix if prefix != -1 else m.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) except json.JSONDecodeError: pass if matches: break return matches, start_pos def 
_parse_mistral_token_tool_calls(answer: str, tool_names: list[str]): """Parse Mistral/Devstral-style tool calls with [TOOL_CALLS] and [ARGS] special tokens. Format: [TOOL_CALLS]func_name[ARGS]{"arg": "value"} """ matches = [] start_pos = None for m in re.finditer( r'\[TOOL_CALLS\]\s*(\S+?)\s*\[ARGS\]\s*', answer ): func_name = m.group(1).strip() if func_name not in tool_names: continue json_str = _extract_balanced_json(answer, m.end()) if json_str is None: continue try: arguments = json.loads(json_str) if start_pos is None: start_pos = m.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) except json.JSONDecodeError: pass return matches, start_pos def _parse_bare_name_tool_calls(answer: str, tool_names: list[str]): """Parse bare function-name style tool calls used by Mistral and similar models. Format: functionName{"arg": "value"} Multiple calls are concatenated directly or separated by whitespace. """ matches = [] start_pos = None # Match tool name followed by opening brace, then extract balanced JSON escaped_names = [re.escape(name) for name in tool_names] pattern = r'(?:' + '|'.join(escaped_names) + r')\s*\{' for match in re.finditer(pattern, answer): text = match.group(0) name = None for n in tool_names: if text.startswith(n): name = n break if not name: continue brace_start = match.end() - 1 json_str = _extract_balanced_json(answer, brace_start) if json_str is None: continue try: arguments = json.loads(json_str) if start_pos is None: start_pos = match.start() matches.append({ "type": "function", "function": { "name": name, "arguments": arguments } }) except json.JSONDecodeError: pass return matches, start_pos def _parse_xml_param_tool_calls(answer: str, tool_names: list[str]): """Parse XML-parameter style tool calls used by Qwen3.5 and similar models. 
Format: value """ matches = [] start_pos = None for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): tc_content = tc_match.group(1) func_match = re.search(r']+)>', tc_content) if not func_match: continue func_name = func_match.group(1).strip() if func_name not in tool_names: continue arguments = {} for param_match in re.finditer(r']+)>\s*(.*?)\s*', tc_content, re.DOTALL): param_name = param_match.group(1).strip() param_value = param_match.group(2).strip() try: param_value = json.loads(param_value) except (json.JSONDecodeError, ValueError): pass # keep as string arguments[param_name] = param_value if start_pos is None: start_pos = tc_match.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) return matches, start_pos def _parse_kimi_tool_calls(answer: str, tool_names: list[str]): """Parse Kimi-K2-style tool calls using pipe-delimited tokens. Format: <|tool_calls_section_begin|> <|tool_call_begin|>functions.func_name:index<|tool_call_argument_begin|>{"arg": "value"}<|tool_call_end|> <|tool_calls_section_end|> """ matches = [] start_pos = None for m in re.finditer( r'<\|tool_call_begin\|>\s*(?:functions\.)?(\S+?)(?::\d+)?\s*<\|tool_call_argument_begin\|>\s*', answer ): func_name = m.group(1).strip() if func_name not in tool_names: continue json_str = _extract_balanced_json(answer, m.end()) if json_str is None: continue try: arguments = json.loads(json_str) if start_pos is None: # Check for section begin marker before the call marker section = answer.rfind('<|tool_calls_section_begin|>', 0, m.start()) start_pos = section if section != -1 else m.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) except json.JSONDecodeError: pass return matches, start_pos def _parse_minimax_tool_calls(answer: str, tool_names: list[str]): """Parse MiniMax-style tool calls using invoke/parameter XML tags. 
Format: value """ matches = [] start_pos = None for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): tc_content = tc_match.group(1) # Split on to handle multiple parallel calls in one block for invoke_match in re.finditer(r'(.*?)', tc_content, re.DOTALL): func_name = invoke_match.group(1).strip() if func_name not in tool_names: continue invoke_body = invoke_match.group(2) arguments = {} for param_match in re.finditer(r'\s*(.*?)\s*', invoke_body, re.DOTALL): param_name = param_match.group(1).strip() param_value = param_match.group(2).strip() try: param_value = json.loads(param_value) except (json.JSONDecodeError, ValueError): pass # keep as string arguments[param_name] = param_value if start_pos is None: start_pos = tc_match.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) return matches, start_pos def _parse_deep_seek_tool_calls(answer: str, tool_names: list[str]): """Parse DeepSeek-style tool calls using fullwidth Unicode token delimiters. Format: <|tool▁calls▁begin|><|tool▁call▁begin|>func_name<|tool▁sep|>{"arg": "value"}<|tool▁call▁end|><|tool▁calls▁end|> """ matches = [] start_pos = None for m in re.finditer( r'<|tool▁call▁begin|>\s*(\S+?)\s*<|tool▁sep|>\s*', answer ): func_name = m.group(1).strip() if func_name not in tool_names: continue json_str = _extract_balanced_json(answer, m.end()) if json_str is None: continue try: arguments = json.loads(json_str) if start_pos is None: # Check for section begin marker before the call marker section = answer.rfind('<|tool▁calls▁begin|>', 0, m.start()) start_pos = section if section != -1 else m.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) except json.JSONDecodeError: pass return matches, start_pos def _parse_glm_tool_calls(answer: str, tool_names: list[str]): """Parse GLM-style tool calls using arg_key/arg_value XML pairs. 
Format: function_name key1 value1 """ matches = [] start_pos = None for tc_match in re.finditer(r'\s*(.*?)\s*', answer, re.DOTALL): tc_content = tc_match.group(1) # First non-tag text is the function name name_match = re.match(r'([^<\s]+)', tc_content.strip()) if not name_match: continue func_name = name_match.group(1).strip() if func_name not in tool_names: continue # Extract arg_key/arg_value pairs keys = [k.group(1).strip() for k in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] vals = [v.group(1).strip() for v in re.finditer(r'\s*(.*?)\s*', tc_content, re.DOTALL)] if len(keys) != len(vals): continue arguments = {} for k, v in zip(keys, vals): try: v = json.loads(v) except (json.JSONDecodeError, ValueError): pass # keep as string arguments[k] = v if start_pos is None: start_pos = tc_match.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) return matches, start_pos def _parse_pythonic_tool_calls(answer: str, tool_names: list[str]): """Parse pythonic-style tool calls used by Llama 4 and similar models. 
Format: [func_name(param1="value1", param2="value2"), func_name2(...)] """ matches = [] start_pos = None # Match a bracketed list of function calls bracket_match = re.search(r'\[([^\[\]]+)\]', answer) if not bracket_match: return matches, start_pos inner = bracket_match.group(1) # Build pattern for known tool names escaped_names = [re.escape(name) for name in tool_names] name_pattern = '|'.join(escaped_names) for call_match in re.finditer( r'(' + name_pattern + r')\(([^)]*)\)', inner ): func_name = call_match.group(1) params_str = call_match.group(2).strip() arguments = {} if params_str: # Parse key="value" pairs, handling commas inside quoted values for param_match in re.finditer( r'(\w+)\s*=\s*("(?:[^"\\]|\\.)*"|\'(?:[^\'\\]|\\.)*\'|[^,\)]+)', params_str ): param_name = param_match.group(1) param_value = param_match.group(2).strip() # Strip surrounding quotes if (param_value.startswith('"') and param_value.endswith('"')) or \ (param_value.startswith("'") and param_value.endswith("'")): param_value = param_value[1:-1] # Try to parse as JSON for numeric/bool/null values try: param_value = json.loads(param_value) except (json.JSONDecodeError, ValueError): pass arguments[param_name] = param_value if start_pos is None: start_pos = bracket_match.start() matches.append({ "type": "function", "function": { "name": func_name, "arguments": arguments } }) return matches, start_pos # Format registry: maps template substrings to the parser and streaming # markers for that format. When a format's hints are NOT found in the # template, its parser and markers are excluded. 
TOOL_CALL_FORMATS = [ { 'template_hints': ['tool▁call▁begin', 'tool▁calls▁begin'], 'parser': _parse_deep_seek_tool_calls, 'markers': ['<|tool▁call▁begin|>', '<|tool▁calls▁begin|>'], }, { 'template_hints': ['<|tool_call_begin|>', 'tool_calls_section'], 'parser': _parse_kimi_tool_calls, 'markers': ['<|tool_call_begin|>', '<|tool_calls_section_begin|>'], }, { 'template_hints': ['to=functions.', '<|channel|>'], 'parser': _parse_channel_tool_calls, 'markers': ['to=functions.', '<|channel|>commentary'], }, { 'template_hints': ['minimax:tool_call'], 'parser': _parse_minimax_tool_calls, 'markers': [''], }, { 'template_hints': [''], 'parser': _parse_glm_tool_calls, 'markers': [''], }, { 'template_hints': [''], 'parser': _parse_xml_param_tool_calls, 'markers': [''], }, { 'template_hints': ['[TOOL_CALLS]'], 'parser': _parse_mistral_token_tool_calls, 'markers': ['[TOOL_CALLS]'], }, { 'template_hints': [''], 'parser': None, 'markers': [''], }, ] # Default ordered list of all specialized parsers. ALL_PARSERS = [ _parse_deep_seek_tool_calls, _parse_kimi_tool_calls, _parse_channel_tool_calls, _parse_minimax_tool_calls, _parse_glm_tool_calls, _parse_xml_param_tool_calls, _parse_mistral_token_tool_calls, _parse_bare_name_tool_calls, _parse_pythonic_tool_calls, ] def detect_tool_call_format(template_str): """Inspect a chat/instruction template to determine which tool call formats are relevant. Uses an exclude-based approach: starts with all parsers/markers, then removes the ones whose hints are not found in the template. Returns (parsers, streaming_markers, check_bare_names). 
""" if not template_str: return None, TOOL_CALL_OPENING_MARKERS, True matched_any = False exclude_parsers = [] exclude_markers = [] matched_markers = [] for fmt in TOOL_CALL_FORMATS: if any(hint in template_str for hint in fmt['template_hints']): matched_any = True matched_markers.extend(fmt['markers']) else: if fmt['parser'] is not None: exclude_parsers.append(fmt['parser']) exclude_markers.extend(fmt['markers']) if not matched_any: return None, TOOL_CALL_OPENING_MARKERS, True parsers = [p for p in ALL_PARSERS if p not in exclude_parsers] markers = [m for m in TOOL_CALL_OPENING_MARKERS if m not in exclude_markers or m in matched_markers] return parsers, markers, False def parse_tool_call(answer: str, tool_names: list[str], return_prefix: bool = False, parsers: list = None): matches = [] start_pos = None def _return(matches, start_pos): if return_prefix: prefix = answer[:start_pos] if matches and start_pos is not None else '' return matches, prefix return matches # Try specialized parsers. 
for parser in (parsers if parsers is not None else ALL_PARSERS): matches, start_pos = parser(answer, tool_names) if matches: return _return(matches, start_pos) # Generic fallback: regex pattern to find the JSON content wrapped in , , , and other tags observed from various models patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)"] for pattern in patterns: for match in re.finditer(pattern, answer, re.DOTALL): if match.group(2) is None: continue # remove backtick wraps if present candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip()) candidate = re.sub(r"```$", "", candidate.strip()) # unwrap inner tags candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL) # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually if re.search(r"\}\s*\n\s*\{", candidate) is not None: candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) if not candidate.strip().startswith("["): candidate = "[" + candidate + "]" candidates = [] try: # parse the candidate JSON into a dictionary candidates = json.loads(candidate) if not isinstance(candidates, list): candidates = [candidates] except json.JSONDecodeError: # Ignore invalid JSON silently continue for candidate_dict in candidates: checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names) if checked_candidate is not None: if start_pos is None: start_pos = match.start() matches.append(checked_candidate) # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags if len(matches) == 0: try: candidate = answer # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually if re.search(r"\}\s*\n\s*\{", candidate) is not None: candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate) if not candidate.strip().startswith("["): candidate = "[" + candidate + "]" # parse the 
candidate JSON into a dictionary candidates = json.loads(candidate) if not isinstance(candidates, list): candidates = [candidates] for candidate_dict in candidates: checked_candidate = check_and_sanitize_tool_call_candidate(candidate_dict, tool_names) if checked_candidate is not None: matches.append(checked_candidate) except json.JSONDecodeError: # Ignore invalid JSON silently pass return _return(matches, start_pos) ================================================ FILE: modules/tool_use.py ================================================ import importlib.util import json from modules import shared from modules.logging_colors import logger from modules.utils import natural_keys, sanitize_filename def get_available_tools(): """Return sorted list of tool script names from user_data/tools/*.py.""" tools_dir = shared.user_data_dir / 'tools' tools_dir.mkdir(parents=True, exist_ok=True) return sorted((p.stem for p in tools_dir.glob('*.py')), key=natural_keys) def load_tools(selected_names): """ Import selected tool scripts and return their definitions and executors. 
Returns (tool_defs, executors) where: - tool_defs: list of OpenAI-format tool dicts - executors: dict mapping function_name -> execute callable """ tool_defs = [] executors = {} for name in selected_names: name = sanitize_filename(name) if not name: continue path = shared.user_data_dir / 'tools' / f'{name}.py' if not path.exists(): continue try: spec = importlib.util.spec_from_file_location(f"tool_{name}", str(path)) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) except Exception: logger.exception(f'Failed to load tool script "{name}"') continue tool_def = getattr(module, 'tool', None) execute_fn = getattr(module, 'execute', None) if tool_def is None or execute_fn is None: logger.warning(f'Tool "{name}" is missing a "tool" dict or "execute" function.') continue func_name = tool_def.get('function', {}).get('name', name) if func_name in executors: logger.warning(f'Tool "{name}" declares function name "{func_name}" which conflicts with an already loaded tool. Skipping.') continue tool_defs.append(tool_def) executors[func_name] = execute_fn return tool_defs, executors def execute_tool(func_name, arguments, executors): """Execute a tool by function name. 
Returns result as a JSON string.""" fn = executors.get(func_name) if fn is None: return json.dumps({"error": f"Unknown tool: {func_name}"}) try: if isinstance(arguments, str): arguments = json.loads(arguments) result = fn(arguments) return json.dumps(result) if not isinstance(result, str) else result except Exception as e: logger.exception(f'Tool "{func_name}" execution failed') return json.dumps({"error": str(e)}) ================================================ FILE: modules/torch_utils.py ================================================ import gc import torch from accelerate.utils import is_npu_available, is_xpu_available from transformers import is_torch_npu_available, is_torch_xpu_available from modules import shared def get_device(): if hasattr(shared.model, 'device'): return shared.model.device elif torch.cuda.is_available(): return torch.device('cuda') elif torch.backends.mps.is_available(): return torch.device('mps') elif is_torch_xpu_available(): return torch.device('xpu:0') elif is_torch_npu_available(): return torch.device('npu:0') else: return None def clear_torch_cache(): gc.collect() if not shared.args.cpu: if torch.cuda.is_available(): torch.cuda.empty_cache() elif is_xpu_available(): torch.xpu.empty_cache() elif is_npu_available(): torch.npu.empty_cache() elif torch.backends.mps.is_available(): if hasattr(torch.backends.mps, 'empty_cache'): torch.backends.mps.empty_cache() ================================================ FILE: modules/training.py ================================================ import os os.environ["WANDB_MODE"] = "offline" # os.environ["WANDB_DISABLED"] = "true" import json import math import random import shutil import sys import threading import time import traceback from datetime import datetime from pathlib import Path import yaml import gradio as gr from modules import shared, ui, utils from modules.evaluate import ( calculate_perplexity, generate_markdown_table, save_past_evaluations ) from modules.logging_colors import 
logger from modules.models import reload_model PARAMETERS = ["lora_name", "always_override", "all_linear", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "text_dataset", "higher_rank_limit", "warmup_steps", "optimizer", "stride_length", "stop_at_loss", "add_eos_token", "excess_length", "report_to"] WANT_INTERRUPT = False train_log = {} train_template = {} def create_ui(): mu = shared.args.multi_user with gr.Tab("Training", elem_id="training-tab"): with gr.Tab('Train LoRA', elem_id='lora-train-tab'): tmp = gr.State('') with gr.Row(): with gr.Column(): gr.Markdown("[Tutorial](https://github.com/oobabooga/text-generation-webui/wiki/05-%E2%80%90-Training-Tab)") with gr.Row(): copy_from = gr.Dropdown(label='Copy parameters from', value='None', choices=utils.get_available_loras(), elem_classes=['slim-dropdown'], interactive=not mu) ui.create_refresh_button(copy_from, lambda: None, lambda: {'choices': utils.get_available_loras()}, 'refresh-button', interactive=not mu) with gr.Row(): with gr.Column(scale=5): lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file') with gr.Column(): always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'): gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM and adapter size.") all_linear = gr.Checkbox(label='Target all linear layers', value=True, info='Targets every nn.Linear layer except lm_head. 
Works for any model architecture. When checked, the individual module checkboxes below are ignored.', elem_classes=['no-background']) with gr.Row(): with gr.Column(): q_proj_en = gr.Checkbox(label='Enable q_proj', value=True) with gr.Column(): v_proj_en = gr.Checkbox(label='Enable v_proj', value=True) with gr.Column(): k_proj_en = gr.Checkbox(label='Enable k_proj', value=False) with gr.Column(): o_proj_en = gr.Checkbox(label='Enable o_proj', value=False) with gr.Column(): gate_proj_en = gr.Checkbox(label='Enable gate_proj', value=False) with gr.Column(): down_proj_en = gr.Checkbox(label='Enable down_proj', value=False) with gr.Column(): up_proj_en = gr.Checkbox(label='Enable up_proj', value=False) with gr.Row(): with gr.Column(): lora_rank = gr.Slider(label='LoRA Rank', value=8, minimum=0, maximum=1024, step=4, info='Also called dimension count. Higher values = larger file, more content control. Smaller values = smaller file, less control. Use 4 or 8 for style, 128 or 256 to teach, 1024+ for fine-detail on big data. More VRAM is needed for higher ranks.') lora_alpha = gr.Slider(label='LoRA Alpha', value=16, minimum=0, maximum=2048, step=4, info='This divided by the rank becomes the scaling of the LoRA. Higher means stronger. A good standard value is twice your Rank.') batch_size = gr.Slider(label='Batch Size', value=32, minimum=0, maximum=1024, step=4, info='Global batch size. The two batch sizes together determine gradient accumulation (gradientAccum = batch / microBatch). Higher gradient accum values lead to better quality training.') micro_batch_size = gr.Slider(label='Micro Batch Size', value=4, minimum=1, maximum=128, step=1, info='Per-device batch size (NOTE: multiple devices not yet implemented). Increasing this will increase VRAM usage.') cutoff_len = gr.Slider(label='Cutoff Length', minimum=0, maximum=4096, value=512, step=32, info='Maximum sequence length in tokens. For instruction datasets, conversations longer than this are dropped. 
For text datasets, documents are split into chunks of this size. Higher values require more VRAM.') with gr.Column(): save_steps = gr.Number(label='Save every n steps', value=0, info='If above 0, a full training checkpoint (adapter weights, optimizer, scheduler) will be saved every time this many steps pass. Training can be resumed from these checkpoints.') epochs = gr.Number(label='Epochs', value=3, info='Number of times every entry in the dataset should be fed into training. So 1 means feed each item in once, 5 means feed it in five times, etc.') learning_rate = gr.Textbox(label='Learning Rate', value='3e-4', info='In scientific notation. 3e-4 is a good starting base point. 1e-2 is extremely high, 1e-6 is extremely low.') with gr.Row(): lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='cosine', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown']) with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'): with gr.Row(): with gr.Column(): lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.0, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. 
(reasonable numbers are 1.5-1.8)') with gr.Row(): optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Optimizer algorithm. adamw_torch is the standard choice. adamw_bnb_8bit uses less VRAM. adafactor is memory-efficient for large models.', elem_classes=['slim-dropdown']) with gr.Column(): warmup_steps = gr.Number(label='Warmup Steps', value=100, info='For this many steps at the start, the learning rate is gradually ramped up from 0 to the target value. This prevents unstable updates early in training.') add_eos_token = gr.Checkbox(label='Add EOS token', value=True, info="Adds EOS token for each document in text datasets.") excess_length = gr.Dropdown(label='Excess length', value='drop', choices=['drop', 'truncate'], info='What to do with conversations that exceed the cutoff length. "Drop" removes them entirely (recommended). "Truncate" cuts from the right, which may produce incomplete responses.', elem_classes=['slim-dropdown']) higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.') report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True) with gr.Column(): with gr.Tab(label='Chat Dataset'): with gr.Row(): dataset = gr.Dropdown(choices=utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with chat conversations (messages or ShareGPT format). 
Each row is one conversation.', elem_classes=['slim-dropdown'], interactive=not mu) ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': utils.get_chat_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu) with gr.Row(): format = gr.Dropdown(choices=get_instruction_templates(), value='None', label='Instruction Template', info='Select an instruction template for formatting the dataset, or "Chat Template" to use the model\'s built-in chat template.', elem_classes=['slim-dropdown'], interactive=not mu) ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_instruction_templates()}, 'refresh-button', interactive=not mu) with gr.Tab(label="Text Dataset"): with gr.Row(): text_dataset = gr.Dropdown(choices=utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets')), value='None', label='Dataset File', info='A JSON file with a "text" key per row, for pretraining-style training. Each row is one document.', elem_classes=['slim-dropdown'], interactive=not mu) ui.create_refresh_button(text_dataset, lambda: None, lambda: {'choices': utils.get_text_datasets(str(shared.user_data_dir / 'training/datasets'))}, 'refresh-button', interactive=not mu) stride_length = gr.Slider(label='Stride Length', minimum=0, maximum=2048, value=256, step=32, info='Overlap between chunks in tokens. 0 = no overlap. 
Values like 256 or 512 help preserve context across chunk boundaries.') with gr.Row(): eval_dataset = gr.Dropdown(choices=utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.', elem_classes=['slim-dropdown'], interactive=not mu) ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'json')}, 'refresh-button', interactive=not mu) eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.') with gr.Row(): start_button = gr.Button("Start LoRA Training", variant='primary', interactive=not mu) stop_button = gr.Button("Interrupt", interactive=not mu) output = gr.Markdown(value="Ready") with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'): with gr.Row(): with gr.Column(): models = gr.Dropdown(utils.get_available_models(), label='Models', multiselect=True, interactive=not mu) evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + utils.get_datasets(str(shared.user_data_dir / 'training/datasets'), 'txt')[1:], value='wikitext', label='Input dataset', info=f'The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under {shared.user_data_dir}/training/datasets.', interactive=not mu) with gr.Row(): with gr.Column(): stride_length = gr.Slider(label='Stride', minimum=0, maximum=32768, value=512, step=256, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') with gr.Column(): max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. 
If set to 0, the maximum context length for the model will be used.') with gr.Row(): start_current_evaluation = gr.Button("Evaluate loaded model", interactive=not mu) start_evaluation = gr.Button("Evaluate selected models", interactive=not mu) stop_evaluation = gr.Button("Interrupt", interactive=not mu) with gr.Column(): evaluation_log = gr.Markdown(value='') evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True) with gr.Row(): save_comments = gr.Button('Save comments', elem_classes="small-button", interactive=not mu) refresh_table = gr.Button('Refresh the table', elem_classes="small-button", interactive=not mu) # Training events all_params = [lora_name, always_override, all_linear, q_proj_en, v_proj_en, k_proj_en, o_proj_en, gate_proj_en, down_proj_en, up_proj_en, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, text_dataset, higher_rank_limit, warmup_steps, optimizer, stride_length, stop_at_loss, add_eos_token, excess_length, report_to] copy_from.change(do_copy_params, [copy_from] + all_params, all_params) start_button.click(do_train, all_params, output) stop_button.click(do_interrupt, None, None, queue=False) higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha]) # Evaluation events. For some reason, the interrupt event # doesn't work with the .then() syntax, so I write them one # by one in this ugly but functional way. 
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) ev.then(generate_markdown_table, None, evaluation_table, show_progress=False) ev_cur = start_current_evaluation.click( lambda: ['current model'], None, tmp).then( calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False) ev_cur.then(generate_markdown_table, None, evaluation_table, show_progress=False) stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False) refresh_table.click(generate_markdown_table, None, evaluation_table, show_progress=True) save_comments.click( save_past_evaluations, evaluation_table, None).then( lambda: "Comments saved.", None, evaluation_log, show_progress=False) def do_interrupt(): global WANT_INTERRUPT WANT_INTERRUPT = True def do_copy_params(lora_name: str, *args): f_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}/training_parameters.json" if Path(f_name).is_file(): with open(f_name, 'r', encoding='utf-8') as format_file: params: dict[str, str] = json.load(format_file) else: params = {} result = list() for i in range(0, len(PARAMETERS)): key = PARAMETERS[i] if key in params: result.append(params[key]) else: result.append(args[i]) return result def change_rank_limit(use_higher_ranks: bool): mult = 2 if use_higher_ranks else 1 return {"maximum": 1024 * mult, "__type__": "update"}, {"maximum": 2048 * mult, "__type__": "update"} def clean_path(base_path: str, path: str): """Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" path = path.replace('\\', '/').replace('..', '_') if base_path is None: return path return f'{Path(base_path).absolute()}/{path}' def get_instruction_templates(): path = shared.user_data_dir / 'instruction-templates' names = set() for ext in ['yaml', 'yml', 'jinja', 'jinja2']: for f in path.glob(f'*.{ext}'): names.add(f.stem) return ['None', 'Chat 
Template'] + sorted(names, key=utils.natural_keys) def load_template(name): """Load a Jinja2 template string from {user_data_dir}/instruction-templates/.""" path = shared.user_data_dir / 'instruction-templates' for ext in ['jinja', 'jinja2', 'yaml', 'yml']: filepath = path / f'{name}.{ext}' if filepath.exists(): if ext in ['jinja', 'jinja2']: return filepath.read_text(encoding='utf-8') else: data = yaml.safe_load(filepath.read_text(encoding='utf-8')) return data.get('instruction_template', '') return '' def backup_adapter(input_folder): # Get the creation date of the adapter file (safetensors or bin) try: adapter_file = Path(f"{input_folder}/adapter_model.safetensors") if not adapter_file.is_file(): adapter_file = Path(f"{input_folder}/adapter_model.bin") if adapter_file.is_file(): logger.info("Backing up existing LoRA adapter") creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime) creation_date_str = creation_date.strftime("Backup-%Y-%m-%d") # Create the new subfolder subfolder_path = Path(f"{input_folder}/{creation_date_str}") subfolder_path.mkdir(parents=True, exist_ok=True) # Check if the file already exists in the subfolder backup_adapter_file = subfolder_path / adapter_file.name if backup_adapter_file.is_file(): print(" - Backup already exists. 
Skipping backup process.") return # Copy existing files to the new subfolder existing_files = Path(input_folder).iterdir() for file in existing_files: if file.is_file(): shutil.copy2(file, subfolder_path) except Exception as e: print("An error occurred in backup_adapter:", str(e)) def calc_trainable_parameters(model): trainable_params = 0 all_param = 0 for _, param in model.named_parameters(): num_params = param.numel() # if using DS Zero 3 and the weights are initialized empty if num_params == 0 and hasattr(param, "ds_numel"): num_params = param.ds_numel all_param += num_params if param.requires_grad: trainable_params += num_params return trainable_params, all_param def do_train(lora_name: str, always_override: bool, all_linear: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, text_dataset: str, higher_rank_limit: bool, warmup_steps: int, optimizer: str, stride_length: int, stop_at_loss: float, add_eos_token: bool, excess_length: str, report_to: str): import torch import transformers from datasets import Dataset, load_dataset from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, set_peft_model_state_dict ) global WANT_INTERRUPT WANT_INTERRUPT = False # == Input validation / processing == yield "Preparing the input..." if shared.args.loader == 'llama.cpp': yield "Error: LoRA training requires a model loaded with the Transformers loader. GGUF models are not supported for training." return lora_file_path = clean_path(None, lora_name) if lora_file_path.strip() == '': yield "Missing or invalid LoRA file name input." 
return lora_file_path = f"{Path(shared.args.lora_dir)}/{lora_file_path}" actual_lr = float(learning_rate) model_type = type(shared.model).__name__ if model_type == "PeftModelForCausalLM": if len(shared.lora_names) > 0: yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.") else: yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.") time.sleep(5) if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: yield "Cannot input zeroes." return gradient_accumulation_steps = max(1, batch_size // micro_batch_size) original_chat_template = getattr(shared.tokenizer, 'chat_template', None) if shared.tokenizer.pad_token_id is None: shared.tokenizer.pad_token_id = shared.tokenizer.eos_token_id shared.tokenizer.padding_side = "right" def list_target_modules(): if all_linear: return "all-linear" target_mods = [f"{name}_proj" for name, enabled in { "q": q_proj_en, "k": k_proj_en, "v": v_proj_en, "o": o_proj_en, "gate": gate_proj_en, "down": down_proj_en, "up": up_proj_en, }.items() if enabled] return target_mods def normalize_messages(data_point): """Convert a dataset row to OpenAI messages format for apply_chat_template().""" if "messages" in data_point: return data_point["messages"] if "conversations" in data_point: role_map = {"human": "user", "gpt": "assistant"} return [ {"role": role_map.get(turn.get("from", ""), turn.get("from", "")), "content": turn["value"]} for turn in data_point["conversations"] ] raise RuntimeError( f'Dataset row must contain "messages" or "conversations" key. 
' f'Found: {list(data_point.keys())}' ) def tokenize_conversation(data_point): """Tokenize using apply_chat_template() with assistant-only label masking.""" messages = normalize_messages(data_point) full_ids = list(shared.tokenizer.apply_chat_template(messages, tokenize=True, return_dict=False)) # Build labels: -100 for everything, then unmask assistant turns. # This assumes apply_chat_template(messages[:i]) is a token-for-token # prefix of apply_chat_template(messages[:i+1]), which holds for all # standard chat templates (Llama, ChatML, Mistral, etc.). labels = [-100] * len(full_ids) for i, msg in enumerate(messages): if msg["role"] == "assistant": # Tokens up to where this assistant turn starts header_ids = shared.tokenizer.apply_chat_template( messages[:i], tokenize=True, return_dict=False, add_generation_prompt=True ) # Tokens through end of this assistant turn through_ids = shared.tokenizer.apply_chat_template( messages[:i + 1], tokenize=True, return_dict=False ) # Unmask assistant tokens start = len(header_ids) end = min(len(through_ids), len(full_ids)) labels[start:end] = full_ids[start:end] if len(full_ids) > cutoff_len: if excess_length == 'truncate': full_ids = full_ids[:cutoff_len] labels = labels[:cutoff_len] else: return {"input_ids": [], "labels": [], "attention_mask": []} return { "input_ids": full_ids, "labels": labels, "attention_mask": [1] * len(full_ids), } train_template.clear() # == Prep the dataset, format, etc == has_text_dataset = text_dataset not in ['None', ''] has_chat_dataset = dataset not in ['None', ''] if has_text_dataset and has_chat_dataset: yield "Error: select either a Chat Dataset or a Text Dataset, not both." 
return def tokenize_text_data(data): """Tokenize text dataset rows, concatenate, and split into chunks.""" all_tokens = [] for row in data: tokens = shared.tokenizer.encode(row['text']) if add_eos_token: tokens.append(shared.tokenizer.eos_token_id) all_tokens.extend(tokens) stride = int(stride_length) step = cutoff_len - stride if stride > 0 else cutoff_len if step <= 0: return None, "Error: stride length must be smaller than cutoff length." if len(all_tokens) < cutoff_len: return None, "Error: dataset is too short to fill even one chunk of the given cutoff length." chunks = [] for start in range(0, len(all_tokens), step): chunk = all_tokens[start:start + cutoff_len] if len(chunk) == 0: break if len(chunk) < cutoff_len: pad_len = cutoff_len - len(chunk) chunks.append({ "input_ids": chunk + [shared.tokenizer.pad_token_id] * pad_len, "labels": list(chunk) + [-100] * pad_len, "attention_mask": [1] * len(chunk) + [0] * pad_len, }) else: chunks.append({ "input_ids": chunk, "labels": list(chunk), "attention_mask": [1] * cutoff_len, }) return Dataset.from_list(chunks), None if has_text_dataset: train_template["template_type"] = "text_dataset" logger.info("Loading text dataset") data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{text_dataset}.json')) if "text" not in data['train'].column_names: yield "Error: text dataset must have a \"text\" key per row." return train_data, err = tokenize_text_data(data['train']) if err: yield err return if eval_dataset == 'None': eval_data = None else: eval_raw = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) if "text" not in eval_raw['train'].column_names: yield "Error: evaluation dataset must have a \"text\" key per row." return eval_data, err = tokenize_text_data(eval_raw['train']) if err: yield err return elif has_chat_dataset: if format in ['None', '']: yield "Missing format choice input, cannot continue." 
return if format == 'Chat Template': if not getattr(shared.tokenizer, 'chat_template', None): yield "Error: this model's tokenizer does not have a chat template. Select an instruction template instead, or load an instruct/chat model." return else: # Load custom instruction template and set on tokenizer template_str = load_template(format) if not template_str: yield f"Error: could not load instruction template '{format}'." return shared.tokenizer.chat_template = template_str # Unified path — both cases use tokenize_conversation() train_template["template_type"] = "chat_template" logger.info("Loading JSON dataset with chat template format") data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{dataset}.json')) # Validate the first row try: normalize_messages(data['train'][0]) except (RuntimeError, KeyError, IndexError) as e: yield f"Error: {e}" return total = len(data['train']) train_data = data['train'].map( tokenize_conversation, remove_columns=data['train'].column_names, new_fingerprint='%030x' % random.randrange(16**30) ) train_data = train_data.filter(lambda x: len(x['input_ids']) > 0) dropped = total - len(train_data) if dropped > 0: logger.warning(f"Dropped {dropped}/{total} conversations exceeding cutoff length of {cutoff_len} tokens.") if len(train_data) == 0: yield f"Error: all {total} conversations exceed the cutoff length of {cutoff_len} tokens. Increase the cutoff length or shorten your data." return if eval_dataset == 'None': eval_data = None else: eval_data = load_dataset("json", data_files=clean_path(str(shared.user_data_dir / 'training/datasets'), f'{eval_dataset}.json')) eval_data = eval_data['train'].map( tokenize_conversation, remove_columns=eval_data['train'].column_names, new_fingerprint='%030x' % random.randrange(16**30) ) eval_data = eval_data.filter(lambda x: len(x['input_ids']) > 0) else: yield "No dataset selected. Choose a Chat Dataset or a Text Dataset." 
return # == We MUST reload model if it went through any previous training, even failed one == if shared.model_dirty_from_training: selected_model = shared.model_name if selected_model: print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m") try: yield f"Reloading {selected_model}..." reload_model() if shared.model is not None: print("Model reloaded OK, continue with training.") else: yield f"Failed to load {selected_model}." return except Exception: exc = traceback.format_exc() logger.error('Failed to reload the model.') print(exc) yield exc.replace('\n', '\n\n') return # == Start prepping the model itself == if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'): logger.info("Getting model ready") if 'quantization_config' in shared.model.config.to_dict(): prepare_model_for_kbit_training(shared.model) # base model is now frozen and should not be reused for any other LoRA training than this one shared.model_dirty_from_training = True logger.info("Preparing for training") target_modules = list_target_modules() if not target_modules: yield "No target modules selected. Enable at least one module or check 'Target all linear layers'." 
return config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout, bias="none", task_type="CAUSAL_LM" ) # == Backup the existing adapter == if not always_override: backup_adapter(lora_file_path) # == get model trainable params model_trainable_params, model_all_params = calc_trainable_parameters(shared.model) # == Determine if we can resume from a checkpoint == resume_checkpoint = None try: logger.info("Creating LoRA model") lora_model = get_peft_model(shared.model, config) if not always_override and Path(lora_file_path).exists(): # Look for HF Trainer checkpoint dirs (full resumption) checkpoints = sorted(Path(lora_file_path).glob("checkpoint-*"), key=os.path.getmtime) if checkpoints: resume_checkpoint = str(checkpoints[-1]) logger.info(f"Will resume from checkpoint: {resume_checkpoint}") else: # Legacy fallback: load bare adapter weights only safetensors_path = Path(f"{lora_file_path}/adapter_model.safetensors") bin_path = Path(f"{lora_file_path}/adapter_model.bin") if safetensors_path.is_file(): logger.info("Loading existing LoRA data (safetensors)") from safetensors.torch import load_file state_dict_peft = load_file(str(safetensors_path)) set_peft_model_state_dict(lora_model, state_dict_peft) elif bin_path.is_file(): logger.info("Loading existing LoRA data (bin)") state_dict_peft = torch.load(str(bin_path), weights_only=True) set_peft_model_state_dict(lora_model, state_dict_peft) except Exception: yield traceback.format_exc().replace('\n', '\n\n') return class Tracked(): def __init__(self): self.current_steps = 0 self.max_steps = 0 self.did_save = False tracked = Tracked() actual_save_steps = math.ceil(save_steps / gradient_accumulation_steps) class Callbacks(transformers.TrainerCallback): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): tracked.current_steps = state.global_step * gradient_accumulation_steps 
tracked.max_steps = state.max_steps * gradient_accumulation_steps if WANT_INTERRUPT: control.should_epoch_stop = True control.should_training_stop = True def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): tracked.current_steps += 1 if WANT_INTERRUPT: control.should_epoch_stop = True control.should_training_stop = True def on_log(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, logs, **kwargs): train_log.update(logs) train_log.update({"current_steps": tracked.current_steps}) if WANT_INTERRUPT: print("\033[1;31;1mInterrupted by user\033[0;37;0m") print(f"\033[1;30;40mStep: {tracked.current_steps} \033[0;37;0m", end='') if 'loss' in logs: loss = float(logs['loss']) if stop_at_loss > 0 and loss <= stop_at_loss: control.should_epoch_stop = True control.should_training_stop = True print(f"\033[1;31;1mStop Loss {stop_at_loss} reached.\033[0;37;0m") def on_save(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): checkpoint_dir = Path(args.output_dir) / f"checkpoint-{state.global_step}" if checkpoint_dir.exists(): with open(checkpoint_dir / "training_log.json", 'w', encoding='utf-8') as file: json.dump(train_log, file, indent=2) with open(checkpoint_dir / "training_prompt.json", 'w', encoding='utf-8') as file: json.dump(train_template, file, indent=2) # Fix training for mixed precision models for param in shared.model.parameters(): if param.requires_grad: param.data = param.data.float() lora_model.config.use_cache = False def collate_fn(batch): max_len = max(len(item['input_ids']) for item in batch) input_ids, labels, attention_mask = [], [], [] for item in batch: pad_len = max_len - len(item['input_ids']) input_ids.append(item['input_ids'] + [shared.tokenizer.pad_token_id] * pad_len) labels.append(item['labels'] + [-100] * pad_len) 
attention_mask.append(item['attention_mask'] + [0] * pad_len) return { 'input_ids': torch.tensor(input_ids), 'labels': torch.tensor(labels), 'attention_mask': torch.tensor(attention_mask), } trainer = transformers.Trainer( model=lora_model, train_dataset=train_data, eval_dataset=eval_data, args=transformers.TrainingArguments( report_to=report_to if report_to != "None" else "none", per_device_train_batch_size=micro_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps), num_train_epochs=epochs, learning_rate=actual_lr, fp16=False if shared.args.cpu or shared.args.bf16 else True, bf16=shared.args.bf16, optim=optimizer, logging_steps=1, eval_strategy="steps" if eval_data is not None else "no", eval_steps=math.ceil(eval_steps / gradient_accumulation_steps) if eval_data is not None else None, save_strategy="steps" if save_steps > 0 or eval_data is not None else "no", save_steps=actual_save_steps if save_steps > 0 else None, output_dir=lora_file_path, lr_scheduler_type=lr_scheduler_type, load_best_model_at_end=eval_data is not None, # TODO: Enable multi-device support ddp_find_unused_parameters=None, use_cpu=shared.args.cpu, remove_unused_columns=False, ), data_collator=collate_fn, callbacks=[Callbacks()] ) # == Save parameters for reuse == with open(f"{lora_file_path}/training_parameters.json", 'w', encoding='utf-8') as file: local_vars = locals() json.dump({x: local_vars[x] for x in PARAMETERS}, file, indent=2) # == Save training prompt == with open(f"{lora_file_path}/training_prompt.json", 'w', encoding='utf-8') as file: json.dump(train_template, file, indent=2) # == Main run and monitor loop == logger.info("Starting training") yield "Starting..." 
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model) if target_modules == "all-linear": projections_string = "all-linear" else: projections_string = ", ".join([projection.replace("_proj", "") for projection in target_modules]) print(f"Training '{model_type}' model using ({projections_string}) projections") if lora_all_param > 0: print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})") train_log.update({"base_model_name": shared.model_name}) train_log.update({"base_model_class": shared.model.__class__.__name__}) train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)}) train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)}) train_log.update({"projections": projections_string}) if stop_at_loss > 0: print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m") if WANT_INTERRUPT: yield "Interrupted before start." 
return def log_train_dataset(trainer): decoded_entries = [] # Try to decode the entries and write the log file try: # Iterate over the first 10 elements in the dataset (or fewer if there are less than 10) for i in range(min(10, len(trainer.train_dataset))): decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids']) decoded_entries.append({"value": decoded_text}) # Write the log file (shared.user_data_dir / 'logs').mkdir(exist_ok=True) with open(shared.user_data_dir / 'logs' / 'train_dataset_sample.json', 'w') as json_file: json.dump(decoded_entries, json_file, indent=4) logger.info(f"Log file 'train_dataset_sample.json' created in the '{shared.user_data_dir}/logs' directory.") except Exception as e: logger.error(f"Failed to create log file due to error: {e}") thread_error = None def threaded_run(): nonlocal thread_error try: log_train_dataset(trainer) trainer.train(resume_from_checkpoint=resume_checkpoint) # Note: save in the thread in case the gradio thread breaks (eg browser closed) lora_model.save_pretrained(lora_file_path) tracked.did_save = True logger.info("LoRA training run is completed and saved.") # Save log with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file: json.dump(train_log, file, indent=2) except Exception as e: thread_error = e logger.error(f"Training error: {e}") thread = threading.Thread(target=threaded_run) thread.start() last_step = 0 start_time = time.perf_counter() while thread.is_alive(): time.sleep(0.5) if WANT_INTERRUPT: yield "Interrupting, please wait... 
*(Run will stop after the current training step completes.)*" elif tracked.current_steps != last_step: last_step = tracked.current_steps time_elapsed = time.perf_counter() - start_time if time_elapsed <= 0: timer_info = "" total_time_estimate = 999 else: its = tracked.current_steps / time_elapsed if its > 1: timer_info = f"`{its:.2f}` it/s" else: timer_info = f"`{1.0/its:.2f}` s/it" total_time_estimate = (1.0 / its) * (tracked.max_steps) yield f"Running... **{tracked.current_steps}** / **{tracked.max_steps}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" # Check for errors from the training thread if thread_error is not None: yield f"Training failed: {thread_error}" return # Saving in the train thread might fail if an error occurs, so save here if so. if not tracked.did_save: logger.info("Training complete, saving") lora_model.save_pretrained(lora_file_path) # Restore the original chat_template if we changed it for training if shared.tokenizer is not None and hasattr(shared.tokenizer, 'chat_template'): shared.tokenizer.chat_template = original_chat_template if WANT_INTERRUPT: logger.info("Training interrupted.") yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`." else: logger.info("Training complete!") yield f"Done! LoRA saved to `{lora_file_path}`.\n\nBefore testing your new LoRA, make sure to first reload the model, as it is currently dirty from training." 
def format_time(seconds: float): if seconds < 120: return f"`{seconds:.0f}` seconds" minutes = seconds / 60 if minutes < 120: return f"`{minutes:.0f}` minutes" hours = minutes / 60 return f"`{hours:.0f}` hours" ================================================ FILE: modules/transformers_loader.py ================================================ import pprint from pathlib import Path import torch import torch.nn.functional as F import transformers from accelerate import infer_auto_device_map, init_empty_weights from accelerate.utils import is_xpu_available from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig, LogitsProcessor ) import modules.shared as shared from modules.logging_colors import logger from modules.text_generation import get_reply_from_output_ids from modules.torch_utils import get_device transformers.logging.set_verbosity_error() class _StopEverythingStoppingCriteria(transformers.StoppingCriteria): def __init__(self): transformers.StoppingCriteria.__init__(self) def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool: return shared.stop_everything class Stream(transformers.StoppingCriteria): def __init__(self, callback_func=None): self.callback_func = callback_func def __call__(self, input_ids, scores) -> bool: if self.callback_func is not None: self.callback_func(input_ids[0]) return False class LogitsBiasProcessor(LogitsProcessor): def __init__(self, logit_bias={}): self.logit_bias = logit_bias if self.logit_bias: self.keys = list([int(key) for key in self.logit_bias.keys()]) values = [self.logit_bias[str(key)] for key in self.keys] self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device) def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: if self.logit_bias: logits[0, self.keys] += self.values return logits def __repr__(self): return 
f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>" class LogprobProcessor(LogitsProcessor): def __init__(self, logprobs=None): self.logprobs = logprobs self.token_alternatives = {} self.token_alternatives_history = [] def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: if self.logprobs is not None: # 0-5 log_e_probabilities = F.log_softmax(logits, dim=1) top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs) top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]] top_probs = [float(x) for x in top_values[0]] self.token_alternatives = dict(zip(top_tokens, top_probs)) self.token_alternatives_history.append(self.token_alternatives) return logits def __repr__(self): return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>" def load_tokenizer(model_name, tokenizer_dir=None): if tokenizer_dir: path_to_model = Path(tokenizer_dir) else: path_to_model = Path(f"{shared.args.model_dir}/{model_name}/") tokenizer = None if path_to_model.exists(): if shared.args.no_use_fast: logger.info('Loading the tokenizer with use_fast=False.') tokenizer = AutoTokenizer.from_pretrained( path_to_model, trust_remote_code=shared.original_args.trust_remote_code, use_fast=not shared.args.no_use_fast ) return tokenizer def load_model_HF(model_name): torch._dynamo.config.disable = True path_to_model = Path(f'{shared.args.model_dir}/{model_name}') params = { 'low_cpu_mem_usage': True, 'attn_implementation': shared.args.attn_implementation, 'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16, } if shared.original_args.trust_remote_code: params['trust_remote_code'] = True if shared.args.force_safetensors: params['force_safetensors'] = True config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.original_args.trust_remote_code) if 'chatglm' in model_name.lower(): LoaderClass = AutoModel else: if 
config.to_dict().get('is_encoder_decoder', False): LoaderClass = AutoModelForSeq2SeqLM shared.is_seq2seq = True else: LoaderClass = AutoModelForCausalLM # Determine if we should use default loading should_use_default_loading = not any([ shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.disk, shared.args.cpu_memory is not None, ]) # Load the model without any special settings if should_use_default_loading: params['device_map'] = 'auto' logger.info("TRANSFORMERS_PARAMS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params) print() model = LoaderClass.from_pretrained(path_to_model, **params) if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit): device = get_device() if device: model = model.to(device) # Load with quantization and/or offloading else: if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())): logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. 
Falling back to CPU mode.') shared.args.cpu = True if shared.args.cpu: params['torch_dtype'] = torch.float32 else: params['device_map'] = 'auto' if x := get_max_memory_dict(): params['max_memory'] = x if shared.args.load_in_4bit: # See https://github.com/huggingface/transformers/pull/23479/files # and https://huggingface.co/blog/4bit-transformers-bitsandbytes quantization_config_params = { 'load_in_4bit': True, 'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None, 'bnb_4bit_quant_type': shared.args.quant_type, 'bnb_4bit_use_double_quant': shared.args.use_double_quant, 'llm_int8_enable_fp32_cpu_offload': True } params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params) elif shared.args.load_in_8bit: if shared.args.gpu_split: params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True) else: params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True) if params.get('max_memory') is not None: with init_empty_weights(): model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code')) model.tie_weights() params['device_map'] = infer_auto_device_map( model, dtype=torch.int8, max_memory=params.get('max_memory'), no_split_module_classes=model._no_split_modules ) if shared.args.disk: params['offload_folder'] = str(Path(shared.args.disk_cache_dir)) logger.info("TRANSFORMERS_PARAMS=") pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params) print() model = LoaderClass.from_pretrained(path_to_model, **params) return model def get_max_memory_dict(): max_memory = {} if shared.args.cpu_memory > 0: max_memory['cpu'] = f'{shared.args.cpu_memory}GiB' if shared.args.gpu_split: for i, memory in enumerate(shared.args.gpu_split.split(',')): max_memory[i] = f'{memory}GiB' return max_memory if len(max_memory) > 0 else None ================================================ FILE: 
modules/ui.py ================================================ import copy import threading from pathlib import Path import gradio as gr import yaml import extensions import modules.extensions as extensions_module from modules import shared from modules.chat import load_history from modules.utils import gradio # Global state for auto-saving UI settings with debouncing _auto_save_timer = None _auto_save_lock = threading.Lock() _last_interface_state = None _last_preset = None _last_extensions = None _last_show_controls = None _last_theme_state = None with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r', encoding='utf-8') as f: css = f.read() with open(Path(__file__).resolve().parent / '../css/main.css', 'r', encoding='utf-8') as f: css += f.read() with open(Path(__file__).resolve().parent / '../css/katex/katex.min.css', 'r', encoding='utf-8') as f: css += f.read() with open(Path(__file__).resolve().parent / '../css/highlightjs/highlightjs-copy.min.css', 'r', encoding='utf-8') as f: css += f.read() with open(Path(__file__).resolve().parent / '../js/main.js', 'r', encoding='utf-8') as f: js = f.read() with open(Path(__file__).resolve().parent / '../js/global_scope_js.js', 'r', encoding='utf-8') as f: global_scope_js = f.read() with open(Path(__file__).resolve().parent / '../js/save_files.js', 'r', encoding='utf-8') as f: save_files_js = f.read() with open(Path(__file__).resolve().parent / '../js/switch_tabs.js', 'r', encoding='utf-8') as f: switch_tabs_js = f.read() with open(Path(__file__).resolve().parent / '../js/show_controls.js', 'r', encoding='utf-8') as f: show_controls_js = f.read() with open(Path(__file__).resolve().parent / '../js/update_big_picture.js', 'r', encoding='utf-8') as f: update_big_picture_js = f.read() with open(Path(__file__).resolve().parent / '../js/dark_theme.js', 'r', encoding='utf-8') as f: dark_theme_js = f.read() refresh_symbol = '🔄' delete_symbol = '🗑️' save_symbol = '💾' theme = gr.themes.Default( font=['Noto 
Sans', 'Helvetica', 'ui-sans-serif', 'system-ui', 'sans-serif'], font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], ).set( border_color_primary='#c5c5d2', button_large_padding='6px 12px', body_text_color_subdued='#484848', background_fill_secondary='#eaeaea', background_fill_primary='var(--neutral-50)', body_background_fill="white", block_background_fill="#f4f4f4", body_text_color="#333", button_secondary_background_fill="#f4f4f4", button_secondary_border_color="var(--border-color-primary)" ) if not shared.args.old_colors: theme = theme.set( # General Colors border_color_primary='#c5c5d2', body_text_color_subdued='#484848', background_fill_secondary='#eaeaea', background_fill_secondary_dark='var(--selected-item-color-dark, #282930)', background_fill_primary='var(--neutral-50)', background_fill_primary_dark='var(--darker-gray, #1C1C1D)', body_background_fill="white", block_background_fill="transparent", body_text_color='rgb(64, 64, 64)', button_secondary_background_fill="white", button_secondary_border_color="var(--border-color-primary)", input_shadow="none", button_shadow_hover="none", # Dark Mode Colors input_background_fill_dark='var(--darker-gray, #1C1C1D)', checkbox_background_color_dark='var(--darker-gray, #1C1C1D)', block_background_fill_dark='transparent', block_border_color_dark='transparent', input_border_color_dark='var(--border-color-dark, #525252)', input_border_color_focus_dark='var(--border-color-dark, #525252)', checkbox_border_color_dark='var(--border-color-dark, #525252)', border_color_primary_dark='var(--border-color-dark, #525252)', button_secondary_border_color_dark='var(--border-color-dark, #525252)', body_background_fill_dark='var(--dark-gray, #212125)', button_primary_background_fill_dark='transparent', button_secondary_background_fill_dark='transparent', checkbox_label_background_fill_dark='transparent', button_cancel_background_fill_dark='transparent', 
button_secondary_background_fill_hover_dark='var(--selected-item-color-dark, #282930)', checkbox_label_background_fill_hover_dark='var(--selected-item-color-dark, #282930)', table_even_background_fill_dark='var(--darker-gray, #1C1C1D)', table_odd_background_fill_dark='var(--selected-item-color-dark, #282930)', code_background_fill_dark='var(--darker-gray, #1C1C1D)', # Shadows and Radius checkbox_label_shadow='none', block_shadow='none', block_shadow_dark='none', input_shadow_focus='none', input_shadow_focus_dark='none', button_large_radius='0.375rem', button_large_padding='6px 12px', input_radius='0.375rem', block_radius='0', ) if (shared.user_data_dir / "notification.mp3").exists(): audio_notification_js = "document.querySelector('#audio_notification audio')?.play();" else: audio_notification_js = "" def list_model_elements(): from modules.loaders import list_model_elements return list_model_elements() def list_interface_input_elements(): elements = [ 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'top_n_sigma', 'adaptive_target', 'adaptive_decay', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'repetition_penalty_range', 'penalty_alpha', 'guidance_scale', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'max_new_tokens', 'prompt_lookup_num_tokens', 'max_tokens_second', 'do_sample', 'dynamic_temperature', 'temperature_last', 'auto_max_new_tokens', 'ban_eos_token', 'add_bos_token', 'enable_thinking', 'reasoning_effort', 'skip_special_tokens', 'stream', 'static_cache', 'truncation_length', 'seed', 'sampler_priority', 'custom_stopping_strings', 'custom_token_bans', 'negative_prompt', 'dry_sequence_breakers', 'grammar_string', 'navigate_message_index', 
'navigate_direction', 'navigate_message_role', 'edit_message_index', 'edit_message_text', 'edit_message_role', 'branch_index', 'enable_web_search', 'web_search_pages', ] # Chat elements elements += [ 'history', 'search_chat', 'unique_id', 'textbox', 'start_with', 'selected_tools', 'mode', 'chat_style', 'chat-instruct_command', 'character_menu', 'user_menu', 'name2', 'context', 'greeting', 'name1', 'user_bio', 'custom_system_message', 'instruction_template_str', 'chat_template_str', ] # Notebook/default elements elements += [ 'textbox-default', 'textbox-notebook', 'prompt_menu-default', 'prompt_menu-notebook', 'output_textbox', ] # Model elements elements += list_model_elements() # Other elements elements += [ 'show_two_notebook_columns', 'paste_to_attachment', 'include_past_attachments', ] if not shared.args.portable: # Image generation elements elements += [ 'image_prompt', 'image_neg_prompt', 'image_width', 'image_height', 'image_aspect_ratio', 'image_steps', 'image_cfg_scale', 'image_seed', 'image_batch_size', 'image_batch_count', 'image_llm_variations', 'image_llm_variations_prompt', 'image_model_menu', 'image_dtype', 'image_attn_backend', 'image_compile', 'image_cpu_offload', 'image_quant', ] return elements def gather_interface_values(*args): interface_elements = list_interface_input_elements() output = {} for element, value in zip(interface_elements, args): output[element] = value if not shared.args.multi_user: shared.persistent_interface_state = output # Remove the chat input, as it gets cleared after this function call shared.persistent_interface_state.pop('textbox') # Prevent history loss if backend is restarted but UI is not refreshed if (output['history'] is None or (len(output['history'].get('visible', [])) == 0 and len(output['history'].get('internal', [])) == 0)) and output['unique_id'] is not None: output['history'] = load_history(output['unique_id'], output['character_menu'], output['mode']) return output def apply_interface_values(state, 
use_persistent=False): if use_persistent: state = shared.persistent_interface_state if 'textbox-default' in state and 'prompt_menu-default' in state: state.pop('prompt_menu-default') if 'textbox-notebook' in state and 'prompt_menu-notebook' in state: state.pop('prompt_menu-notebook') elements = list_interface_input_elements() if len(state) == 0: return [gr.update() for k in elements] # Dummy, do nothing else: return [state[k] if k in state else gr.update() for k in elements] def save_settings(state, preset, extensions_list, show_controls, theme_state, manual_save=False): output = copy.deepcopy(shared.settings) exclude = [] for k in state: if k in shared.settings and k not in exclude: output[k] = state[k] if preset: output['preset'] = preset output['prompt-notebook'] = state['prompt_menu-default'] if state['show_two_notebook_columns'] else state['prompt_menu-notebook'] if state.get('character_menu'): output['character'] = state['character_menu'] if state.get('user_menu'): output['user'] = state['user_menu'] output['seed'] = int(output['seed']) output['custom_stopping_strings'] = output.get('custom_stopping_strings') or '' output['custom_token_bans'] = output.get('custom_token_bans') or '' output['show_controls'] = show_controls output['dark_theme'] = True if theme_state == 'dark' else False output.pop('instruction_template_str') output.pop('truncation_length') # Handle extensions and extension parameters if manual_save: # Save current extensions and their parameter values output['default_extensions'] = extensions_list for extension_name in extensions_list: extension = getattr(extensions, extension_name, None) if extension: extension = extension.script if hasattr(extension, 'params'): params = getattr(extension, 'params') for param in params: _id = f"{extension_name}-{param}" # Only save if different from default value if param not in shared.default_settings or params[param] != shared.default_settings[param]: output[_id] = params[param] else: # Preserve existing 
extensions and extension parameters during autosave settings_path = shared.user_data_dir / 'settings.yaml' if settings_path.exists(): try: with open(settings_path, 'r', encoding='utf-8') as f: existing_settings = yaml.safe_load(f.read()) or {} # Preserve default_extensions if 'default_extensions' in existing_settings: output['default_extensions'] = existing_settings['default_extensions'] # Preserve extension parameter values for key, value in existing_settings.items(): if any(key.startswith(f"{ext_name}-") for ext_name in extensions_module.available_extensions): output[key] = value except Exception: pass # If we can't read the file, just don't modify extensions # Do not save unchanged settings for key in list(output.keys()): if key in shared.default_settings and output[key] == shared.default_settings[key]: output.pop(key) return yaml.dump(output, sort_keys=False, width=float("inf"), allow_unicode=True) def store_current_state_and_debounce(interface_state, preset, extensions, show_controls, theme_state): """Store current state and trigger debounced save""" global _auto_save_timer, _last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state if shared.args.multi_user: return # Store the current state in global variables _last_interface_state = interface_state _last_preset = preset _last_extensions = extensions _last_show_controls = show_controls _last_theme_state = theme_state # Reset the debounce timer with _auto_save_lock: if _auto_save_timer is not None: _auto_save_timer.cancel() _auto_save_timer = threading.Timer(1.0, _perform_debounced_save) _auto_save_timer.start() def _perform_debounced_save(): """Actually perform the save using the stored state""" global _auto_save_timer try: if _last_interface_state is not None: contents = save_settings(_last_interface_state, _last_preset, _last_extensions, _last_show_controls, _last_theme_state, manual_save=False) settings_path = shared.user_data_dir / 'settings.yaml' 
settings_path.parent.mkdir(exist_ok=True) with open(settings_path, 'w', encoding='utf-8') as f: f.write(contents) except Exception as e: print(f"Auto-save failed: {e}") finally: with _auto_save_lock: _auto_save_timer = None def setup_auto_save(): """Attach auto-save to key UI elements""" if shared.args.multi_user: return change_elements = [ # Chat tab (ui_chat.py) 'start_with', 'enable_web_search', 'web_search_pages', 'mode', 'chat_style', 'chat-instruct_command', 'character_menu', 'user_menu', 'name1', 'name2', 'context', 'greeting', 'user_bio', 'custom_system_message', 'chat_template_str', 'selected_tools', # Parameters tab (ui_parameters.py) - Generation parameters 'preset_menu', 'temperature', 'dynatemp_low', 'dynatemp_high', 'dynatemp_exponent', 'smoothing_factor', 'smoothing_curve', 'min_p', 'top_p', 'top_k', 'typical_p', 'xtc_threshold', 'xtc_probability', 'epsilon_cutoff', 'eta_cutoff', 'tfs', 'top_a', 'top_n_sigma', 'adaptive_target', 'adaptive_decay', 'dry_multiplier', 'dry_allowed_length', 'dry_base', 'repetition_penalty', 'frequency_penalty', 'presence_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'repetition_penalty_range', 'penalty_alpha', 'guidance_scale', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'max_new_tokens', 'prompt_lookup_num_tokens', 'max_tokens_second', 'do_sample', 'dynamic_temperature', 'temperature_last', 'auto_max_new_tokens', 'ban_eos_token', 'add_bos_token', 'enable_thinking', 'reasoning_effort', 'skip_special_tokens', 'stream', 'static_cache', 'seed', 'sampler_priority', 'custom_stopping_strings', 'custom_token_bans', 'negative_prompt', 'dry_sequence_breakers', 'grammar_string', # Default tab (ui_default.py) 'prompt_menu-default', # Notebook tab (ui_notebook.py) 'prompt_menu-notebook', # Session tab (ui_session.py) 'show_controls', 'theme_state', 'show_two_notebook_columns', 'paste_to_attachment', 'include_past_attachments', ] if not shared.args.portable: # Image generation tab (ui_image_generation.py) 
change_elements += [ 'image_prompt', 'image_neg_prompt', 'image_width', 'image_height', 'image_aspect_ratio', 'image_steps', 'image_cfg_scale', 'image_seed', 'image_batch_size', 'image_batch_count', 'image_llm_variations', 'image_llm_variations_prompt', 'image_model_menu', 'image_dtype', 'image_attn_backend', 'image_compile', 'image_cpu_offload', 'image_quant', ] for element_name in change_elements: if element_name in shared.gradio: shared.gradio[element_name].change( gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( store_current_state_and_debounce, gradio('interface_state', 'preset_menu', 'extensions_menu', 'show_controls', 'theme_state'), None, show_progress=False) def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_class, interactive=True): """ Copied from https://github.com/AUTOMATIC1111/stable-diffusion-webui """ def refresh(): refresh_method() args = refreshed_args() if callable(refreshed_args) else refreshed_args return gr.update(**(args or {})) refresh_button = gr.Button(refresh_symbol, elem_classes=elem_class, interactive=interactive) refresh_button.click( fn=lambda: {k: tuple(v) if type(k) is list else v for k, v in refresh().items()}, inputs=[], outputs=[refresh_component] ) return refresh_button ================================================ FILE: modules/ui_chat.py ================================================ import json from functools import partial from pathlib import Path import gradio as gr from PIL import Image from modules import chat, shared, ui, utils from modules.html_generator import chat_html_wrapper from modules.text_generation import stop_everything_event from modules.utils import gradio inputs = ('Chat input', 'interface_state') reload_arr = ('history', 'name1', 'name2', 'mode', 'chat_style', 'character_menu') def create_ui(): mu = shared.args.multi_user shared.gradio['Chat input'] = gr.State() shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 
'metadata': {}}) shared.gradio['display'] = gr.Headless(value={}) with gr.Tab('Chat', elem_id='chat-tab'): with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']): with gr.Column(): with gr.Row(elem_id='past-chats-buttons'): shared.gradio['branch_chat'] = gr.Button('Branch', elem_classes=['refresh-button', 'refresh-button-medium'], elem_id='Branch', interactive=not mu) shared.gradio['rename_chat'] = gr.Button('Rename', elem_classes=['refresh-button', 'refresh-button-medium'], interactive=not mu) shared.gradio['delete_chat'] = gr.Button('🗑️', visible=False, elem_classes='refresh-button', interactive=not mu, elem_id='delete_chat') shared.gradio['Start new chat'] = gr.Button('New chat', elem_classes=['refresh-button', 'refresh-button-medium', 'focus-on-chat-input'], elem_id='new-chat-btn') shared.gradio['Start incognito chat'] = gr.Button('Incognito chat', visible=False, elem_id='incognito-chat-btn') shared.gradio['branch_index'] = gr.Number(value=-1, precision=0, visible=False, elem_id="Branch-index", interactive=True) shared.gradio['search_chat'] = gr.Textbox(placeholder='Search chats...', max_lines=1, elem_id='search_chat') with gr.Row(elem_id='delete-chat-row', visible=False) as shared.gradio['delete-chat-row']: shared.gradio['delete_chat-cancel'] = gr.Button('Cancel', elem_classes=['refresh-button', 'focus-on-chat-input'], elem_id='delete_chat-cancel') shared.gradio['delete_chat-confirm'] = gr.Button('Confirm', variant='stop', elem_classes=['refresh-button', 'focus-on-chat-input'], elem_id='delete_chat-confirm') with gr.Row(elem_id='rename-row', visible=False) as shared.gradio['rename-row']: shared.gradio['rename_to'] = gr.Textbox(label='Rename to:', placeholder='New name', elem_classes=['no-background']) with gr.Row(): shared.gradio['rename_to-cancel'] = gr.Button('Cancel', elem_classes=['refresh-button', 'focus-on-chat-input']) shared.gradio['rename_to-confirm'] = gr.Button('Confirm', elem_classes=['refresh-button', 'focus-on-chat-input'], 
variant='primary') with gr.Row(): shared.gradio['unique_id'] = gr.Radio(label="", elem_classes=['slim-dropdown', 'pretty_scrollbar'], interactive=not mu, elem_id='past-chats') with gr.Row(): with gr.Column(elem_id='chat-col'): shared.gradio['html_display'] = gr.HTML(value=chat_html_wrapper({'internal': [], 'visible': [], 'metadata': {}}, '', '', 'chat', 'cai-chat', '')['html'], visible=True) with gr.Row(elem_id="chat-input-row"): with gr.Column(scale=1, elem_id='gr-hover-container'): gr.HTML(value='
      ', elem_id='gr-hover') with gr.Column(scale=10, elem_id='chat-input-container'): shared.gradio['textbox'] = gr.MultimodalTextbox(label='', placeholder='Send a message', file_types=['text', '.pdf', 'image'], file_count="multiple", elem_id='chat-input', elem_classes=['add_scrollbar']) shared.gradio['typing-dots'] = gr.HTML(value='
      ', label='typing', elem_id='typing-container') with gr.Column(scale=1, elem_id='generate-stop-container'): with gr.Row(): shared.gradio['Stop'] = gr.Button('Stop', elem_id='stop', visible=False) shared.gradio['Generate'] = gr.Button('Send', elem_id='Generate', variant='primary') # Hover menu buttons with gr.Column(elem_id='chat-buttons'): shared.gradio['Regenerate'] = gr.Button('Regenerate (Ctrl + Enter)', elem_id='Regenerate') shared.gradio['Continue'] = gr.Button('Continue (Alt + Enter)', elem_id='Continue') shared.gradio['Remove last'] = gr.Button('Remove last reply (Ctrl + Shift + Backspace)', elem_id='Remove-last') shared.gradio['Impersonate'] = gr.Button('Impersonate (Ctrl + Shift + M)', elem_id='Impersonate') shared.gradio['Send dummy message'] = gr.Button('Send dummy message') shared.gradio['Send dummy reply'] = gr.Button('Send dummy reply') shared.gradio['send-chat-to-notebook'] = gr.Button('Send to Notebook') shared.gradio['show_controls'] = gr.Checkbox(value=shared.settings['show_controls'], label='Show controls (Ctrl+S)', elem_id='show-controls') with gr.Row(elem_id='chat-controls', elem_classes=['pretty_scrollbar']): with gr.Column(): with gr.Row(): shared.gradio['start_with'] = gr.Textbox(label='Start reply with', placeholder='Sure thing!', value=shared.settings['start_with'], elem_classes=['add_scrollbar']) gr.HTML("") shared.gradio['reasoning_effort'] = gr.Dropdown(value=shared.settings['reasoning_effort'], choices=['low', 'medium', 'high'], label='Reasoning effort', info='Used by GPT-OSS.') shared.gradio['enable_thinking'] = gr.Checkbox(value=shared.settings['enable_thinking'], label='Enable thinking', info='Used by Seed-OSS and pre-2507 Qwen3.') gr.HTML("") shared.gradio['enable_web_search'] = gr.Checkbox(value=shared.settings.get('enable_web_search', False), label='Activate web search', elem_id='web-search') with gr.Row(visible=shared.settings.get('enable_web_search', False)) as shared.gradio['web_search_row']: 
shared.gradio['web_search_pages'] = gr.Number(value=shared.settings.get('web_search_pages', 3), precision=0, label='Number of pages to download', minimum=1, maximum=10) gr.HTML("") from modules.tool_use import get_available_tools shared.gradio['selected_tools'] = gr.CheckboxGroup(choices=get_available_tools(), value=shared.settings.get('selected_tools', []), label='Tools', info='Functions the model can call during generation.', elem_id='tools-group') shared.gradio['tools_refresh'] = gr.Button('Refresh list', elem_id='tools-refresh-btn', visible=False) shared.gradio['tools_refresh'].click(fn=lambda: gr.update(choices=get_available_tools()), inputs=[], outputs=[shared.gradio['selected_tools']]) def sync_web_tools(selected): if 'web_search' in selected and 'fetch_webpage' not in selected and 'fetch_webpage' in get_available_tools(): selected.append('fetch_webpage') return gr.update(value=selected) shared.gradio['selected_tools'].change(fn=sync_web_tools, inputs=[shared.gradio['selected_tools']], outputs=[shared.gradio['selected_tools']], show_progress=False) gr.HTML("") with gr.Row(): shared.gradio['mode'] = gr.Radio(choices=['instruct', 'chat-instruct', 'chat'], value=None, label='Mode', info='In instruct and chat-instruct modes, the template under Parameters > Instruction template is used.', elem_id='chat-mode') with gr.Row(): shared.gradio['chat_style'] = gr.Dropdown(choices=utils.get_available_chat_styles(), label='Chat style', value=shared.settings['chat_style'], visible=shared.settings['mode'] != 'instruct') with gr.Row(): shared.gradio['chat-instruct_command'] = gr.Textbox(value=shared.settings['chat-instruct_command'], lines=12, label='Command for chat-instruct mode', info='<|character|> and <|prompt|> get replaced with the bot name and the regular chat prompt respectively.', visible=shared.settings['mode'] == 'chat-instruct', elem_classes=['add_scrollbar']) gr.HTML("") with gr.Row(): shared.gradio['count_tokens'] = gr.Button('Count tokens', size='sm') 
shared.gradio['token_display'] = gr.HTML(value='', elem_classes='token-display') # Hidden elements for version navigation and editing with gr.Row(visible=False): shared.gradio['navigate_message_index'] = gr.Number(value=-1, precision=0, elem_id="Navigate-message-index") shared.gradio['navigate_direction'] = gr.Textbox(value="", elem_id="Navigate-direction") shared.gradio['navigate_message_role'] = gr.Textbox(value="", elem_id="Navigate-message-role") shared.gradio['navigate_version'] = gr.Button(elem_id="Navigate-version") shared.gradio['edit_message_index'] = gr.Number(value=-1, precision=0, elem_id="Edit-message-index") shared.gradio['edit_message_text'] = gr.Textbox(value="", elem_id="Edit-message-text") shared.gradio['edit_message_role'] = gr.Textbox(value="", elem_id="Edit-message-role") shared.gradio['edit_message'] = gr.Button(elem_id="Edit-message") def create_character_settings_ui(): mu = shared.args.multi_user with gr.Tab('Character', elem_id="character-tab"): with gr.Row(): with gr.Column(scale=8): with gr.Tab("Character"): with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(value=shared.settings['character'], choices=utils.get_available_characters(), label='Character', elem_id='character-menu', info='Used in chat and chat-instruct modes.', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': utils.get_available_characters()}, 'refresh-button', interactive=not mu) shared.gradio['save_character'] = gr.Button('💾', elem_classes='refresh-button', elem_id="save-character", interactive=not mu) shared.gradio['delete_character'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu) shared.gradio['restore_character'] = gr.Button('Restore character', elem_classes='refresh-button', interactive=True, elem_id='restore-character') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') shared.gradio['context'] = 
gr.Textbox(value=shared.settings['context'], lines=10, label='Context', elem_classes=['add_scrollbar'], elem_id="character-context") shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=5, label='Greeting', elem_classes=['add_scrollbar'], elem_id="character-greeting") with gr.Tab("User"): with gr.Row(): shared.gradio['user_menu'] = gr.Dropdown(value=shared.settings['user'], choices=utils.get_available_users(), label='User', elem_id='user-menu', info='Select a user profile.', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['user_menu'], lambda: None, lambda: {'choices': utils.get_available_users()}, 'refresh-button', interactive=not mu) shared.gradio['save_user'] = gr.Button('💾', elem_classes='refresh-button', elem_id="save-user", interactive=not mu) shared.gradio['delete_user'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu) shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Name') shared.gradio['user_bio'] = gr.Textbox(value=shared.settings['user_bio'], lines=10, label='Description', info='Here you can optionally write a description of yourself.', placeholder='{{user}}\'s personality: ...', elem_classes=['add_scrollbar'], elem_id="user-description") with gr.Tab('Chat history'): with gr.Row(): with gr.Column(): shared.gradio['save_chat_history'] = gr.Button(value='Save history') with gr.Column(): shared.gradio['load_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'], label='Upload History JSON') with gr.Tab('Upload character'): with gr.Tab('YAML or JSON'): with gr.Row(): shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json', '.yaml'], label='JSON or YAML File', interactive=not mu) shared.gradio['upload_img_bot'] = gr.Image(type='filepath', label='Profile Picture (optional)', interactive=not mu) shared.gradio['Submit character'] = gr.Button(value='Submit', interactive=False) with gr.Tab('TavernAI PNG'): with gr.Row(): with 
gr.Column(): shared.gradio['upload_img_tavern'] = gr.Image(type='filepath', label='TavernAI PNG File', elem_id='upload_img_tavern', interactive=not mu) shared.gradio['tavern_json'] = gr.State() with gr.Column(): shared.gradio['tavern_name'] = gr.Textbox(value='', lines=1, label='Name', interactive=False) shared.gradio['tavern_desc'] = gr.Textbox(value='', lines=10, label='Description', interactive=False, elem_classes=['add_scrollbar']) shared.gradio['Submit tavern character'] = gr.Button(value='Submit', interactive=False) with gr.Column(scale=1): shared.gradio['character_picture'] = gr.Image(label='Character picture', type='filepath', interactive=not mu) shared.gradio['your_picture'] = gr.Image(label='Your picture', type='filepath', value=Image.open(shared.user_data_dir / 'cache' / 'pfp_me.png') if (shared.user_data_dir / 'cache' / 'pfp_me.png').exists() else None, interactive=not mu) def create_chat_settings_ui(): mu = shared.args.multi_user with gr.Tab('Instruction template'): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', info="After selecting the template, click on \"Load\" to load and apply it.", value='None', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button') shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu) with gr.Column(): pass with gr.Row(): with gr.Column(): shared.gradio['instruction_template_str'] = gr.Textbox(value=shared.settings['instruction_template_str'], label='Instruction template', lines=24, info='This gets 
autodetected; you usually don\'t need to change it. Used in instruct and chat-instruct modes.', elem_classes=['add_scrollbar', 'monospace'], elem_id='instruction-template-str') with gr.Row(): shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button']) with gr.Column(): shared.gradio['chat_template_str'] = gr.Textbox(value=shared.settings['chat_template_str'], label='Chat template', lines=22, elem_classes=['add_scrollbar', 'monospace'], info='Defines how the chat prompt in chat/chat-instruct modes is generated.', elem_id='chat-template-str') def create_event_handlers(): # Obsolete variables, kept for compatibility with old extensions shared.input_params = gradio(inputs) shared.reload_inputs = gradio(reload_arr) # Morph HTML updates instead of updating everything shared.gradio['display'].change(None, gradio('display'), None, js="(data) => handleMorphdomUpdate(data)") shared.gradio['Generate'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda x: (x, {"text": "", "files": []}), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( lambda: None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.add("_generating")').then( chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False).then( None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.remove("_generating")').then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['textbox'].submit( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda x: (x, {"text": "", "files": []}), gradio('textbox'), gradio('Chat input', 'textbox'), show_progress=False).then( lambda: None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.add("_generating")').then( 
chat.generate_chat_reply_wrapper, gradio(inputs), gradio('display', 'history'), show_progress=False).then( None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.remove("_generating")').then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Regenerate'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.add("_generating")').then( partial(chat.generate_chat_reply_wrapper, regenerate=True), gradio(inputs), gradio('display', 'history'), show_progress=False).then( None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.remove("_generating")').then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Continue'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.add("_generating")').then( partial(chat.generate_chat_reply_wrapper, _continue=True), gradio(inputs), gradio('display', 'history'), show_progress=False).then( None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.remove("_generating")').then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Impersonate'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda x: x, gradio('textbox'), gradio('Chat input'), show_progress=False).then( lambda: None, None, None, js='() => document.getElementById("chat").parentNode.parentNode.parentNode.classList.add("_generating")').then( chat.impersonate_wrapper, gradio(inputs), gradio('textbox', 'display'), show_progress=False).then( None, None, None, js='() => 
document.getElementById("chat").parentNode.parentNode.parentNode.classList.remove("_generating")').then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Send dummy message'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_send_dummy_message_click, gradio('textbox', 'interface_state'), gradio('history', 'display', 'textbox'), show_progress=False) shared.gradio['Send dummy reply'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_send_dummy_reply_click, gradio('textbox', 'interface_state'), gradio('history', 'display', 'textbox'), show_progress=False) shared.gradio['Remove last'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_remove_last_click, gradio('interface_state'), gradio('history', 'display', 'textbox'), show_progress=False) shared.gradio['Stop'].click( stop_everything_event, None, None, queue=False).then( chat.redraw_html, gradio(reload_arr), gradio('display'), show_progress=False) if not shared.args.multi_user: shared.gradio['unique_id'].select( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_unique_id_select, gradio('interface_state'), gradio('history', 'display'), show_progress=False) shared.gradio['Start new chat'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_start_new_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) shared.gradio['Start incognito chat'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_start_incognito_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) shared.gradio['delete_chat-confirm'].click( ui.gather_interface_values, 
gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_delete_chat_confirm_click, gradio('interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False) shared.gradio['branch_chat'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_branch_chat_click, gradio('interface_state'), gradio('history', 'display', 'unique_id', 'branch_index'), show_progress=False) shared.gradio['rename_chat'].click(chat.handle_rename_chat_click, None, gradio('rename_to', 'rename-row'), show_progress=False) shared.gradio['rename_to-cancel'].click(lambda: gr.update(visible=False), None, gradio('rename-row'), show_progress=False) shared.gradio['rename_to-confirm'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_rename_chat_confirm, gradio('rename_to', 'interface_state'), gradio('unique_id', 'rename-row')) shared.gradio['rename_to'].submit( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_rename_chat_confirm, gradio('rename_to', 'interface_state'), gradio('unique_id', 'rename-row'), show_progress=False) shared.gradio['search_chat'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_search_chat_change, gradio('interface_state'), gradio('unique_id'), show_progress=False) shared.gradio['load_chat_history'].upload( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_upload_chat_history, gradio('load_chat_history', 'interface_state'), gradio('history', 'display', 'unique_id'), show_progress=False).then( None, None, None, js=f'() => {{{ui.switch_tabs_js}; switch_to_chat()}}') shared.gradio['character_menu'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_character_menu_change, gradio('interface_state'), 
gradio('history', 'display', 'name1', 'name2', 'character_picture', 'greeting', 'context', 'unique_id'), show_progress=False).then( None, None, None, js=f'() => {{{ui.update_big_picture_js}; updateBigPicture()}}') shared.gradio['character_picture'].change(chat.handle_character_picture_change, gradio('character_picture'), None, show_progress=False) shared.gradio['mode'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_mode_change, gradio('interface_state'), gradio('history', 'display', 'chat_style', 'chat-instruct_command', 'unique_id'), show_progress=False).then( None, gradio('mode'), None, js="(mode) => {const characterContainer = document.getElementById('character-menu').parentNode.parentNode; const isInChatTab = document.querySelector('#chat-controls').contains(characterContainer); if (isInChatTab) { characterContainer.style.display = mode === 'instruct' ? 'none' : ''; } if (mode === 'instruct') document.querySelectorAll('.bigProfilePicture').forEach(el => el.remove());}") shared.gradio['chat_style'].change(chat.redraw_html, gradio(reload_arr), gradio('display'), show_progress=False) shared.gradio['navigate_version'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_navigate_version_click, gradio('interface_state'), gradio('history', 'display'), show_progress=False) shared.gradio['edit_message'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_edit_message_click, gradio('interface_state'), gradio('history', 'display'), show_progress=False) # Save/delete a character shared.gradio['save_character'].click(chat.handle_save_character_click, gradio('name2'), gradio('save_character_filename', 'character_saver'), show_progress=False) shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'), show_progress=False) 
shared.gradio['load_template'].click(chat.handle_load_template_click, gradio('instruction_template'), gradio('instruction_template_str', 'instruction_template'), show_progress=False) shared.gradio['save_template'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_save_template_click, gradio('instruction_template_str'), gradio('save_filename', 'save_root', 'save_contents', 'save_root_state', 'file_saver'), show_progress=False) shared.gradio['restore_character'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.restore_character_for_ui, gradio('interface_state'), gradio('interface_state', 'name2', 'context', 'greeting', 'character_picture'), show_progress=False) shared.gradio['delete_template'].click(chat.handle_delete_template_click, gradio('instruction_template'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['save_chat_history'].click( lambda x: json.dumps(x, indent=4), gradio('history'), gradio('temporary_text')).then( None, gradio('temporary_text', 'character_menu', 'mode'), None, js=f'(hist, char, mode) => {{{ui.save_files_js}; saveHistory(hist, char, mode)}}') shared.gradio['Submit character'].click( chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'), show_progress=False).then( None, None, None, js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') shared.gradio['Submit tavern character'].click( chat.upload_tavern_character, gradio('upload_img_tavern', 'tavern_json'), gradio('character_menu'), show_progress=False).then( None, None, None, js=f'() => {{{ui.switch_tabs_js}; switch_to_character()}}') shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character')) shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character')) 
shared.gradio['upload_img_tavern'].upload(chat.check_tavern_character, gradio('upload_img_tavern'), gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) shared.gradio['upload_img_tavern'].clear(lambda: (None, None, None, gr.update(interactive=False)), None, gradio('tavern_name', 'tavern_desc', 'tavern_json', 'Submit tavern character'), show_progress=False) shared.gradio['your_picture'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_your_picture_change, gradio('your_picture', 'interface_state'), gradio('display'), show_progress=False) shared.gradio['send_instruction_to_notebook'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_send_instruction_click, gradio('interface_state'), gradio('textbox-notebook', 'textbox-default', 'output_textbox'), show_progress=False).then( None, None, None, js=f'() => {{{ui.switch_tabs_js}; switch_to_notebook()}}') shared.gradio['send-chat-to-notebook'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_send_chat_click, gradio('interface_state'), gradio('textbox-notebook', 'textbox-default', 'output_textbox'), show_progress=False).then( None, None, None, js=f'() => {{{ui.switch_tabs_js}; switch_to_notebook()}}') shared.gradio['show_controls'].change(None, gradio('show_controls'), None, js=f'(x) => {{{ui.show_controls_js}; toggle_controls(x)}}') shared.gradio['count_tokens'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( chat.count_prompt_tokens, gradio('textbox', 'interface_state'), gradio('token_display'), show_progress=False) shared.gradio['enable_web_search'].change( lambda x: gr.update(visible=x), gradio('enable_web_search'), gradio('web_search_row') ) # User menu event handlers shared.gradio['user_menu'].change( ui.gather_interface_values, 
gradio(shared.input_elements), gradio('interface_state')).then( chat.handle_user_menu_change, gradio('interface_state'), gradio('name1', 'user_bio', 'your_picture'), show_progress=False) shared.gradio['save_user'].click(chat.handle_save_user_click, gradio('name1'), gradio('save_user_filename', 'user_saver'), show_progress=False) shared.gradio['delete_user'].click(lambda: gr.update(visible=True), None, gradio('user_deleter'), show_progress=False) ================================================ FILE: modules/ui_default.py ================================================ from pathlib import Path import gradio as gr from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( generate_reply_wrapper, get_token_ids, stop_everything_event ) from modules.ui_notebook import store_notebook_state_and_debounce from modules.utils import gradio inputs = ('textbox-default', 'interface_state') outputs = ('output_textbox', 'html-default') def create_ui(): mu = shared.args.multi_user with gr.Row(visible=shared.settings['show_two_notebook_columns']) as shared.gradio['default-tab']: with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['textbox-default'] = gr.Textbox(value="", lines=27, label='Input', elem_classes=['textbox_default', 'add_scrollbar']) shared.gradio['token-counter-default'] = gr.HTML(value="0", elem_id="default-token-counter") with gr.Row(): shared.gradio['Continue-default'] = gr.Button('Continue') shared.gradio['Stop-default'] = gr.Button('Stop', elem_id='stop', visible=False) shared.gradio['Generate-default'] = gr.Button('Generate', variant='primary') with gr.Row(): shared.gradio['prompt_menu-default'] = gr.Dropdown(choices=utils.get_available_prompts(), value=shared.settings['prompt-notebook'], label='Prompt', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['prompt_menu-default'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, 
'refresh-button', interactive=not mu) shared.gradio['new_prompt-default'] = gr.Button('New', elem_classes='refresh-button', interactive=not mu) shared.gradio['rename_prompt-default'] = gr.Button('Rename', elem_classes='refresh-button', interactive=not mu) shared.gradio['delete_prompt-default'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu) # Rename elements (initially hidden) shared.gradio['rename_prompt_to-default'] = gr.Textbox(label="New name", elem_classes=['no-background'], visible=False) shared.gradio['rename_prompt-cancel-default'] = gr.Button('Cancel', elem_classes=['refresh-button'], visible=False) shared.gradio['rename_prompt-confirm-default'] = gr.Button('Confirm', elem_classes=['refresh-button'], variant='primary', visible=False) # Delete confirmation elements (initially hidden) shared.gradio['delete_prompt-cancel-default'] = gr.Button('Cancel', elem_classes=['refresh-button'], visible=False) shared.gradio['delete_prompt-confirm-default'] = gr.Button('Confirm', variant='stop', elem_classes=['refresh-button'], visible=False) with gr.Column(): with gr.Tab('Raw'): shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output', elem_id='textbox-default', elem_classes=['textbox_default_output', 'add_scrollbar']) with gr.Tab('Markdown'): shared.gradio['markdown_render-default'] = gr.Button('Render') shared.gradio['markdown-default'] = gr.Markdown() with gr.Tab('HTML'): shared.gradio['html-default'] = gr.HTML() with gr.Tab('Logits'): with gr.Row(): with gr.Column(scale=10): shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities') with gr.Column(scale=1): shared.gradio['use_samplers-default'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background']) with gr.Row(): shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar']) shared.gradio['logits-default-previous'] = gr.Textbox(lines=23, label='Previous output', 
elem_classes=['textbox_logits', 'add_scrollbar']) with gr.Tab('Tokens'): shared.gradio['get_tokens-default'] = gr.Button('Get token IDs for the input') shared.gradio['tokens-default'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits', 'add_scrollbar', 'monospace']) def create_event_handlers(): shared.gradio['Generate-default'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-default', 'Generate-default')).then( generate_reply_wrapper, gradio('textbox-default', 'interface_state'), gradio(outputs), show_progress=False).then( lambda state, left, right: state.update({'textbox-default': left, 'output_textbox': right}), gradio('interface_state', 'textbox-default', 'output_textbox'), None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-default', 'Generate-default')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['textbox-default'].submit( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-default', 'Generate-default')).then( generate_reply_wrapper, gradio('textbox-default', 'interface_state'), gradio(outputs), show_progress=False).then( lambda state, left, right: state.update({'textbox-default': left, 'output_textbox': right}), gradio('interface_state', 'textbox-default', 'output_textbox'), None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-default', 'Generate-default')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Continue-default'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-default', 'Generate-default')).then( 
generate_reply_wrapper, gradio('output_textbox', 'interface_state'), gradio(outputs), show_progress=False).then( lambda state, left, right: state.update({'textbox-default': left, 'output_textbox': right}), gradio('interface_state', 'textbox-default', 'output_textbox'), None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-default', 'Generate-default')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Stop-default'].click(stop_everything_event, None, None, queue=False) shared.gradio['markdown_render-default'].click(lambda x: x, gradio('output_textbox'), gradio('markdown-default'), queue=False) shared.gradio['prompt_menu-default'].change(lambda x: (load_prompt(x), ""), gradio('prompt_menu-default'), gradio('textbox-default', 'output_textbox'), show_progress=False) shared.gradio['new_prompt-default'].click(handle_new_prompt, None, gradio('prompt_menu-default'), show_progress=False) # Input change handler to save input (reusing notebook's debounced saving) shared.gradio['textbox-default'].change( store_notebook_state_and_debounce, gradio('textbox-default', 'prompt_menu-default'), None, show_progress=False ) shared.gradio['delete_prompt-default'].click( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)], None, gradio('delete_prompt-default', 'delete_prompt-cancel-default', 'delete_prompt-confirm-default'), show_progress=False) shared.gradio['delete_prompt-cancel-default'].click( lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio('delete_prompt-default', 'delete_prompt-cancel-default', 'delete_prompt-confirm-default'), show_progress=False) shared.gradio['delete_prompt-confirm-default'].click( handle_delete_prompt_confirm_default, gradio('prompt_menu-default'), gradio('prompt_menu-default', 'delete_prompt-default', 'delete_prompt-cancel-default', 'delete_prompt-confirm-default'), show_progress=False) 
shared.gradio['rename_prompt-default'].click( handle_rename_prompt_click_default, gradio('prompt_menu-default'), gradio('rename_prompt_to-default', 'rename_prompt-default', 'rename_prompt-cancel-default', 'rename_prompt-confirm-default'), show_progress=False) shared.gradio['rename_prompt-cancel-default'].click( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio('rename_prompt_to-default', 'rename_prompt-default', 'rename_prompt-cancel-default', 'rename_prompt-confirm-default'), show_progress=False) shared.gradio['rename_prompt-confirm-default'].click( handle_rename_prompt_confirm_default, gradio('rename_prompt_to-default', 'prompt_menu-default'), gradio('prompt_menu-default', 'rename_prompt_to-default', 'rename_prompt-default', 'rename_prompt-cancel-default', 'rename_prompt-confirm-default'), show_progress=False) shared.gradio['textbox-default'].change(lambda x: f"{count_tokens(x)}", gradio('textbox-default'), gradio('token-counter-default'), show_progress=False) shared.gradio['get_logits-default'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( logits.get_next_logits, gradio('textbox-default', 'interface_state', 'use_samplers-default', 'logits-default'), gradio('logits-default', 'logits-default-previous'), show_progress=False) shared.gradio['get_tokens-default'].click(get_token_ids, gradio('textbox-default'), gradio('tokens-default'), show_progress=False) def handle_new_prompt(): new_name = utils.current_time() # Create the new prompt file prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text("In this story,", encoding='utf-8') return gr.update(choices=utils.get_available_prompts(), value=new_name) def handle_delete_prompt_confirm_default(prompt_name): available_prompts = utils.get_available_prompts() current_index = 
available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 (shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) available_prompts = utils.get_available_prompts() if available_prompts: new_value = available_prompts[min(current_index, len(available_prompts) - 1)] else: new_value = utils.current_time() (shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) (shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") available_prompts = [new_value] return [ gr.update(choices=available_prompts, value=new_value), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) ] def handle_rename_prompt_click_default(current_name): return [ gr.update(value=current_name, visible=True), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True) ] def handle_rename_prompt_confirm_default(new_name, current_name): old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" if old_path.exists() and not new_path.exists(): old_path.rename(new_path) available_prompts = utils.get_available_prompts() return [ gr.update(choices=available_prompts, value=new_name), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) ] ================================================ FILE: modules/ui_file_saving.py ================================================ import traceback import gradio as gr from modules import chat, presets, shared, ui, utils from modules.utils import gradio, sanitize_filename def create_ui(): mu = shared.args.multi_user # Server-side per-session root paths for the generic file saver/deleter. # Set by the handler that opens the dialog, read by the confirm handler. # Using gr.State so they are session-scoped and safe for multi-user. 
shared.gradio['save_root_state'] = gr.State(None) shared.gradio['delete_root_state'] = gr.State(None) # Text file saver with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_saver']: shared.gradio['save_filename'] = gr.Textbox(lines=1, label='File name') shared.gradio['save_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) shared.gradio['save_contents'] = gr.Textbox(lines=10, label='File contents') with gr.Row(): shared.gradio['save_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) # Text file deleter with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['file_deleter']: shared.gradio['delete_filename'] = gr.Textbox(lines=1, label='File name') shared.gradio['delete_root'] = gr.Textbox(lines=1, label='File folder', info='For reference. Unchangeable.', interactive=False) with gr.Row(): shared.gradio['delete_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['delete_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu) # Character saver/deleter with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_saver']: shared.gradio['save_character_filename'] = gr.Textbox(lines=1, label='File name', info=f'The character will be saved to your {shared.user_data_dir}/characters folder with this base filename.') with gr.Row(): shared.gradio['save_character_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_character_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['character_deleter']: gr.Markdown('Confirm the character deletion?') with gr.Row(): shared.gradio['delete_character_cancel'] = gr.Button('Cancel', 
elem_classes="small-button") shared.gradio['delete_character_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu) # User saver/deleter with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_saver']: shared.gradio['save_user_filename'] = gr.Textbox(lines=1, label='File name', info=f'The user profile will be saved to your {shared.user_data_dir}/users folder with this base filename.') with gr.Row(): shared.gradio['save_user_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_user_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['user_deleter']: gr.Markdown('Confirm the user deletion?') with gr.Row(): shared.gradio['delete_user_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['delete_user_confirm'] = gr.Button('Delete', elem_classes="small-button", variant='stop', interactive=not mu) # Preset saver with gr.Group(visible=False, elem_classes='file-saver') as shared.gradio['preset_saver']: shared.gradio['save_preset_filename'] = gr.Textbox(lines=1, label='File name', info=f'The preset will be saved to your {shared.user_data_dir}/presets folder with this base filename.') shared.gradio['save_preset_contents'] = gr.Textbox(lines=10, label='File contents') with gr.Row(): shared.gradio['save_preset_cancel'] = gr.Button('Cancel', elem_classes="small-button") shared.gradio['save_preset_confirm'] = gr.Button('Save', elem_classes="small-button", variant='primary', interactive=not mu) def create_event_handlers(): shared.gradio['save_preset'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( handle_save_preset_click, gradio('interface_state'), gradio('save_preset_contents', 'save_preset_filename', 'preset_saver'), show_progress=False) 
shared.gradio['delete_preset'].click(handle_delete_preset_click, gradio('preset_menu'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['save_grammar'].click(handle_save_grammar_click, gradio('grammar_string'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False) shared.gradio['delete_grammar'].click(handle_delete_grammar_click, gradio('grammar_file'), gradio('delete_filename', 'delete_root', 'delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['save_preset_confirm'].click(handle_save_preset_confirm_click, gradio('save_preset_filename', 'save_preset_contents'), gradio('preset_menu', 'preset_saver'), show_progress=False) shared.gradio['save_confirm'].click(handle_save_confirm_click, gradio('save_root_state', 'save_filename', 'save_contents'), gradio('save_root_state', 'file_saver'), show_progress=False) shared.gradio['delete_confirm'].click(handle_delete_confirm_click, gradio('delete_root_state', 'delete_filename'), gradio('delete_root_state', 'file_deleter'), show_progress=False) shared.gradio['save_character_confirm'].click(handle_save_character_confirm_click, gradio('name2', 'greeting', 'context', 'character_picture', 'save_character_filename'), gradio('character_menu', 'character_saver'), show_progress=False) shared.gradio['delete_character_confirm'].click(handle_delete_character_confirm_click, gradio('character_menu'), gradio('character_menu', 'character_deleter'), show_progress=False) shared.gradio['save_preset_cancel'].click(lambda: gr.update(visible=False), None, gradio('preset_saver'), show_progress=False) shared.gradio['save_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_saver')) shared.gradio['delete_cancel'].click(lambda: gr.update(visible=False), None, gradio('file_deleter')) shared.gradio['save_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_saver'), 
show_progress=False) shared.gradio['delete_character_cancel'].click(lambda: gr.update(visible=False), None, gradio('character_deleter'), show_progress=False) # User save/delete event handlers shared.gradio['save_user_confirm'].click(handle_save_user_confirm_click, gradio('name1', 'user_bio', 'your_picture', 'save_user_filename'), gradio('user_menu', 'user_saver'), show_progress=False) shared.gradio['delete_user_confirm'].click(handle_delete_user_confirm_click, gradio('user_menu'), gradio('user_menu', 'user_deleter'), show_progress=False) shared.gradio['save_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_saver'), show_progress=False) shared.gradio['delete_user_cancel'].click(lambda: gr.update(visible=False), None, gradio('user_deleter'), show_progress=False) def handle_save_preset_confirm_click(filename, contents): try: filename = sanitize_filename(filename) utils.save_file(str(shared.user_data_dir / "presets" / f"{filename}.yaml"), contents) available_presets = utils.get_available_presets() output = gr.update(choices=available_presets, value=filename) except Exception: output = gr.update() traceback.print_exc() return [ output, gr.update(visible=False) ] def handle_save_confirm_click(root_state, filename, contents): try: if root_state is None: return None, gr.update(visible=False) filename = sanitize_filename(filename) utils.save_file(root_state + filename, contents) except Exception: traceback.print_exc() return None, gr.update(visible=False) def handle_delete_confirm_click(root_state, filename): try: if root_state is None: return None, gr.update(visible=False) filename = sanitize_filename(filename) utils.delete_file(root_state + filename) except Exception: traceback.print_exc() return None, gr.update(visible=False) def handle_save_character_confirm_click(name2, greeting, context, character_picture, filename): try: chat.save_character(name2, greeting, context, character_picture, filename) available_characters = 
utils.get_available_characters() output = gr.update(choices=available_characters, value=filename) except Exception: output = gr.update() traceback.print_exc() return [ output, gr.update(visible=False) ] def handle_delete_character_confirm_click(character): try: index = str(utils.get_available_characters().index(character)) chat.delete_character(character) output = chat.update_character_menu_after_deletion(index) except Exception: output = gr.update() traceback.print_exc() return [ output, gr.update(visible=False) ] def handle_save_preset_click(state): contents = presets.generate_preset_yaml(state) return [ contents, "My Preset", gr.update(visible=True) ] def handle_delete_preset_click(preset): root = str(shared.user_data_dir / "presets") + "/" return [ f"{preset}.yaml", root, root, gr.update(visible=True) ] def handle_save_grammar_click(grammar_string): root = str(shared.user_data_dir / "grammars") + "/" return [ grammar_string, "My Fancy Grammar.gbnf", root, root, gr.update(visible=True) ] def handle_delete_grammar_click(grammar_file): root = str(shared.user_data_dir / "grammars") + "/" return [ grammar_file, root, root, gr.update(visible=True) ] def handle_save_user_confirm_click(name1, user_bio, your_picture, filename): try: chat.save_user(name1, user_bio, your_picture, filename) available_users = utils.get_available_users() output = gr.update(choices=available_users, value=filename) except Exception: output = gr.update() traceback.print_exc() return [ output, gr.update(visible=False) ] def handle_delete_user_confirm_click(user): try: index = str(utils.get_available_users().index(user)) chat.delete_user(user) output = chat.update_user_menu_after_deletion(index) except Exception: output = gr.update() traceback.print_exc() return [ output, gr.update(visible=False) ] ================================================ FILE: modules/ui_image_generation.py ================================================ import json import os import random import time import traceback 
from datetime import datetime
from pathlib import Path

import gradio as gr
from PIL.PngImagePlugin import PngInfo

from modules import shared, ui, utils
from modules.image_models import (
    get_pipeline_type,
    load_image_model,
    unload_image_model
)
from modules.image_utils import open_image_safely
from modules.logging_colors import logger
from modules.text_generation import stop_everything_event
from modules.utils import check_model_loaded, gradio

ASPECT_RATIOS = {
    "1:1 Square": (1, 1),
    "16:9 Cinema": (16, 9),
    "9:16 Mobile": (9, 16),
    "4:3 Photo": (4, 3),
    "Custom": None,
}

STEP = 16
IMAGES_PER_PAGE = 32

# Settings keys to save in PNG metadata (Generate tab only)
METADATA_SETTINGS_KEYS = [
    'image_prompt',
    'image_neg_prompt',
    'image_width',
    'image_height',
    'image_aspect_ratio',
    'image_steps',
    'image_seed',
    'image_cfg_scale',
]

# Cache for all image paths
_image_cache = []
_cache_timestamp = 0


def round_to_step(value, step=STEP):
    """Round `value` to the nearest multiple of `step` (Python round(): ties go to even)."""
    return round(value / step) * step


def clamp(value, min_val, max_val):
    """Limit `value` to the closed range [min_val, max_val]."""
    return max(min_val, min(max_val, value))


def apply_aspect_ratio(aspect_ratio, current_width, current_height, min_size=256, max_size=2048):
    """Return (width, height) adjusted to the named aspect ratio.

    The shorter side of the target ratio is kept from the current dimensions
    (for 1:1, the smaller current side) and the longer side is derived from it,
    rounded to STEP. If the derived side would exceed `max_size`, it is pinned
    to `max_size` and the other side is re-derived, so the requested ratio is
    preserved instead of being distorted by independent clamping.

    "Custom" or an unknown ratio returns the current dimensions unchanged.
    """
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return current_width, current_height

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]

    if w_ratio == h_ratio:
        new_width = new_height = min(current_width, current_height)
    elif w_ratio < h_ratio:
        # Portrait: width is the anchor, height is derived.
        new_width = current_width
        new_height = round_to_step(current_width * h_ratio / w_ratio)
        if new_height > max_size:
            new_height = max_size
            new_width = round_to_step(max_size * w_ratio / h_ratio)
    else:
        # Landscape: height is the anchor, width is derived.
        new_height = current_height
        new_width = round_to_step(current_height * w_ratio / h_ratio)
        if new_width > max_size:
            new_width = max_size
            new_height = round_to_step(max_size * h_ratio / w_ratio)

    new_width = clamp(new_width, min_size, max_size)
    new_height = clamp(new_height, min_size, max_size)
    return int(new_width), int(new_height)


def update_height_from_width(width, aspect_ratio):
    """Return the height matching `width` for the selected ratio.

    Returns gr.update() (no change) for "Custom" or an unknown ratio. The result
    is clamped to [256, 2048]; because this handler only updates the height,
    very large widths can still yield a height that no longer matches the ratio.
    """
    if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS:
        return gr.update()

    w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
    new_height = round_to_step(width * h_ratio / w_ratio)
    new_height = clamp(new_height, 256, 2048)
    return int(new_height)
def update_width_from_height(height, aspect_ratio): if aspect_ratio == "Custom" or aspect_ratio not in ASPECT_RATIOS: return gr.update() w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] new_width = round_to_step(height * w_ratio / h_ratio) new_width = clamp(new_width, 256, 2048) return int(new_width) def swap_dimensions_and_update_ratio(width, height, aspect_ratio): new_width, new_height = height, width new_ratio = "Custom" for name, ratios in ASPECT_RATIOS.items(): if ratios is None: continue w_r, h_r = ratios expected_height = new_width * h_r / w_r if abs(expected_height - new_height) < STEP: new_ratio = name break return new_width, new_height, new_ratio def build_generation_metadata(state, actual_seed): """Build metadata dict from generation settings.""" metadata = {} for key in METADATA_SETTINGS_KEYS: if key in state: metadata[key] = state[key] # Store the actual seed used (not -1) metadata['image_seed'] = actual_seed metadata['generated_at'] = datetime.now().isoformat() metadata['model'] = shared.image_model_name return metadata def save_generated_images(images, state, actual_seed): """Save images with generation metadata embedded in PNG. 
Returns list of saved file paths.""" if shared.args.multi_user: return [] date_str = datetime.now().strftime("%Y-%m-%d") folder_path = str(shared.user_data_dir / "image_outputs" / date_str) os.makedirs(folder_path, exist_ok=True) metadata = build_generation_metadata(state, actual_seed) metadata_json = json.dumps(metadata, ensure_ascii=False) saved_paths = [] for idx, img in enumerate(images): timestamp = datetime.now().strftime("%H-%M-%S") filename = f"TGW_{timestamp}_{actual_seed:010d}_{idx:03d}.png" filepath = os.path.join(folder_path, filename) # Create PNG metadata png_info = PngInfo() png_info.add_text("image_gen_settings", metadata_json) # Save with metadata img.save(filepath, pnginfo=png_info) saved_paths.append(filepath) return saved_paths def read_image_metadata(image_path): """Read generation metadata from PNG file.""" try: img = open_image_safely(image_path) if img is None: return None try: if hasattr(img, 'text') and 'image_gen_settings' in img.text: return json.loads(img.text['image_gen_settings']) finally: img.close() except Exception as e: logger.debug(f"Could not read metadata from {image_path}: {e}") return None def format_metadata_for_display(metadata): """Format metadata as readable text.""" if not metadata: return "No generation settings found in this image." lines = [] # Display in a nice order display_order = [ ('image_prompt', 'Prompt'), ('image_neg_prompt', 'Negative Prompt'), ('image_width', 'Width'), ('image_height', 'Height'), ('image_aspect_ratio', 'Aspect Ratio'), ('image_steps', 'Steps'), ('image_cfg_scale', 'CFG Scale'), ('image_seed', 'Seed'), ('model', 'Model'), ('generated_at', 'Generated At'), ] for key, label in display_order: if key in metadata: value = metadata[key] if key in ['image_prompt', 'image_neg_prompt'] and value: # Truncate long prompts for display if len(str(value)) > 200: value = str(value)[:200] + "..." 
lines.append(f"**{label}:** {value}") return "\n\n".join(lines) def get_all_history_images(force_refresh=False): """Get all history images sorted by modification time (newest first). Uses caching.""" global _image_cache, _cache_timestamp output_dir = str(shared.user_data_dir / "image_outputs") if not os.path.exists(output_dir): return [] # Check if we need to refresh cache current_time = time.time() if not force_refresh and _image_cache and (current_time - _cache_timestamp) < 2: return _image_cache image_files = [] for root, _, files in os.walk(output_dir): for file in files: if file.endswith((".png", ".jpg", ".jpeg")): full_path = os.path.join(root, file) image_files.append((full_path, os.path.getmtime(full_path))) image_files.sort(key=lambda x: x[1], reverse=True) _image_cache = [x[0] for x in image_files] _cache_timestamp = current_time return _image_cache def get_paginated_images(page=0, force_refresh=False): """Get images for a specific page.""" all_images = get_all_history_images(force_refresh) total_images = len(all_images) total_pages = max(1, (total_images + IMAGES_PER_PAGE - 1) // IMAGES_PER_PAGE) # Clamp page to valid range page = max(0, min(page, total_pages - 1)) start_idx = page * IMAGES_PER_PAGE end_idx = min(start_idx + IMAGES_PER_PAGE, total_images) page_images = all_images[start_idx:end_idx] return page_images, page, total_pages, total_images def get_initial_page_info(): """Get page info string for initial load.""" _, page, total_pages, total_images = get_paginated_images(0) return f"Page {page + 1} of {total_pages} ({total_images} total images)" def refresh_gallery(current_page=0): """Refresh gallery with current page.""" images, page, total_pages, total_images = get_paginated_images(current_page, force_refresh=True) page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" return images, page, page_info def go_to_page(page_num, current_page): """Go to a specific page (1-indexed input).""" try: page = int(page_num) - 1 # 
Convert to 0-indexed except (ValueError, TypeError): page = current_page images, page, total_pages, total_images = get_paginated_images(page) page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" return images, page, page_info def next_page(current_page): """Go to next page.""" images, page, total_pages, total_images = get_paginated_images(current_page + 1) page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" return images, page, page_info def prev_page(current_page): """Go to previous page.""" images, page, total_pages, total_images = get_paginated_images(current_page - 1) page_info = f"Page {page + 1} of {total_pages} ({total_images} total images)" return images, page, page_info def on_gallery_select(evt: gr.SelectData, current_page): """Handle image selection from gallery.""" if evt.index is None: return "", "Select an image to view its settings" if not _image_cache: get_all_history_images() all_images = _image_cache total_images = len(all_images) # Calculate the actual index in the full list start_idx = current_page * IMAGES_PER_PAGE actual_idx = start_idx + evt.index if actual_idx >= total_images: return "", "Image not found" image_path = all_images[actual_idx] metadata = read_image_metadata(image_path) metadata_display = format_metadata_for_display(metadata) return image_path, metadata_display def send_to_generate(selected_image_path): """Load settings from selected image and return updates for all Generate tab inputs.""" if not selected_image_path or not os.path.exists(selected_image_path): return [gr.update()] * 8 + ["No image selected"] metadata = read_image_metadata(selected_image_path) if not metadata: return [gr.update()] * 8 + ["No settings found in this image"] # Return updates for each input element in order updates = [ gr.update(value=metadata.get('image_prompt', '')), gr.update(value=metadata.get('image_neg_prompt', '')), gr.update(value=metadata.get('image_width', 1024)), 
gr.update(value=metadata.get('image_height', 1024)), gr.update(value=metadata.get('image_aspect_ratio', '1:1 Square')), gr.update(value=metadata.get('image_steps', 9)), gr.update(value=metadata.get('image_seed', -1)), gr.update(value=metadata.get('image_cfg_scale', 0.0)), ] status = f"✓ Settings loaded from image (seed: {metadata.get('image_seed', 'unknown')})" return updates + [status] def read_dropped_image_metadata(image_path): """Read metadata from a dropped/uploaded image.""" if not image_path: return "Drop an image to view its generation settings." metadata = read_image_metadata(image_path) return format_metadata_for_display(metadata) def create_ui(): if shared.settings['image_model_menu'] != 'None': shared.image_model_name = shared.settings['image_model_menu'] with gr.Tab("Image AI", elem_id="image-ai-tab"): with gr.Tabs(): # TAB 1: GENERATE with gr.TabItem("Generate"): with gr.Row(): with gr.Column(scale=4, min_width=350): shared.gradio['image_prompt'] = gr.Textbox( label="Prompt", placeholder="Describe your imagination...", lines=3, autofocus=True, value=shared.settings['image_prompt'] ) shared.gradio['image_neg_prompt'] = gr.Textbox( label="Negative Prompt", placeholder="Low quality...", lines=3, value=shared.settings['image_neg_prompt'] ) shared.gradio['image_llm_variations'] = gr.Checkbox( value=shared.settings['image_llm_variations'], label='LLM Prompt Variations', elem_id="llm-prompt-variations", ) shared.gradio['image_llm_variations_prompt'] = gr.Textbox( value=shared.settings['image_llm_variations_prompt'], label='Variation Prompt', lines=3, placeholder='Instructions for generating prompt variations...', visible=shared.settings['image_llm_variations'], info='Use the loaded LLM to generate creative prompt variations for each sequential batch.' 
) shared.gradio['image_generate_btn'] = gr.Button("Generate", variant="primary", size="lg") shared.gradio['image_stop_btn'] = gr.Button("Stop", size="lg", visible=False) shared.gradio['image_progress'] = gr.HTML( value=progress_bar_html(), elem_id="image-progress" ) gr.Markdown("### Dimensions") with gr.Row(): with gr.Column(): shared.gradio['image_width'] = gr.Slider(256, 2048, value=shared.settings['image_width'], step=STEP, label="Width") with gr.Column(): shared.gradio['image_height'] = gr.Slider(256, 2048, value=shared.settings['image_height'], step=STEP, label="Height") shared.gradio['image_swap_btn'] = gr.Button("⇄ Swap", elem_classes='refresh-button', scale=0, min_width=80, elem_id="swap-height-width") with gr.Row(): shared.gradio['image_aspect_ratio'] = gr.Radio( choices=["1:1 Square", "16:9 Cinema", "9:16 Mobile", "4:3 Photo", "Custom"], value=shared.settings['image_aspect_ratio'], label="Aspect Ratio", interactive=True ) gr.Markdown("### Config") with gr.Row(): with gr.Column(): shared.gradio['image_steps'] = gr.Slider(1, 100, value=shared.settings['image_steps'], step=1, label="Steps") shared.gradio['image_cfg_scale'] = gr.Slider( 0.0, 10.0, value=shared.settings['image_cfg_scale'], step=0.1, label="CFG Scale", info="Z-Image Turbo: 0.0 | Qwen: 4.0" ) shared.gradio['image_seed'] = gr.Number(label="Seed", value=shared.settings['image_seed'], precision=0, info="-1 = Random") with gr.Column(): shared.gradio['image_batch_size'] = gr.Slider(1, 32, value=shared.settings['image_batch_size'], step=1, label="Batch Size (VRAM Heavy)", info="Generates N images at once.") shared.gradio['image_batch_count'] = gr.Slider(1, 128, value=shared.settings['image_batch_count'], step=1, label="Sequential Count (Loop)", info="Repeats the generation N times.") with gr.Column(scale=6, min_width=500): with gr.Column(elem_classes=["viewport-container"]): shared.gradio['image_output_gallery'] = gr.Gallery(label="Output", show_label=False, columns=2, rows=2, height="80vh", 
object_fit="contain", preview=True, elem_id="image-output-gallery") # TAB 2: GALLERY (with pagination) with gr.TabItem("Gallery"): with gr.Row(): with gr.Column(scale=3): # Pagination controls with gr.Row(): shared.gradio['image_refresh_history'] = gr.Button("🔄 Refresh", elem_classes="refresh-button") shared.gradio['image_prev_page'] = gr.Button("◀ Prev Page", elem_classes="refresh-button") shared.gradio['image_page_info'] = gr.Markdown(value=get_initial_page_info, elem_id="image-page-info") shared.gradio['image_next_page'] = gr.Button("Next Page ▶", elem_classes="refresh-button") shared.gradio['image_page_input'] = gr.Number(value=1, label="Page", precision=0, minimum=1, scale=0, min_width=80) shared.gradio['image_go_to_page'] = gr.Button("Go", elem_classes="refresh-button", scale=0, min_width=50) # State for current page and selected image path shared.gradio['image_current_page'] = gr.State(value=0) shared.gradio['image_selected_path'] = gr.State(value="") # Paginated gallery using gr.Gallery shared.gradio['image_history_gallery'] = gr.Gallery( value=lambda: get_paginated_images(0)[0], label="Image History", show_label=False, columns=6, object_fit="cover", height="auto", allow_preview=True, elem_id="image-history-gallery" ) with gr.Column(scale=1): gr.Markdown("### Generation Settings") shared.gradio['image_settings_display'] = gr.Markdown("Select an image to view its settings") shared.gradio['image_send_to_generate'] = gr.Button("Send to Generate", variant="primary") shared.gradio['image_gallery_status'] = gr.Markdown("") gr.Markdown("### Import Image") shared.gradio['image_drop_upload'] = gr.Image( label="Drop image here to view settings", type="filepath", height=150 ) # TAB 3: MODEL with gr.TabItem("Model"): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['image_model_menu'] = gr.Dropdown( choices=utils.get_available_image_models(), value=shared.settings['image_model_menu'], label='Model', elem_classes='slim-dropdown' ) 
shared.gradio['image_refresh_models'] = gr.Button("🔄", elem_classes='refresh-button', scale=0, min_width=40) shared.gradio['image_load_model'] = gr.Button("Load", variant='primary', elem_classes='refresh-button') shared.gradio['image_unload_model'] = gr.Button("Unload", elem_classes='refresh-button') gr.Markdown("## Settings") with gr.Row(): with gr.Column(): shared.gradio['image_quant'] = gr.Dropdown( label='Quantization', choices=['none', 'bnb-8bit', 'bnb-4bit', 'torchao-int8wo', 'torchao-fp4', 'torchao-float8wo'], value=shared.settings['image_quant'], info='BnB: bitsandbytes quantization. torchao: int8wo, fp4, float8wo.' ) shared.gradio['image_dtype'] = gr.Dropdown( choices=['bfloat16', 'float16'], value=shared.settings['image_dtype'], label='Data Type', info='bfloat16 recommended for modern GPUs' ) shared.gradio['image_attn_backend'] = gr.Dropdown( choices=['sdpa', 'flash_attention_2'], value=shared.settings['image_attn_backend'], label='Attention Backend', info='SDPA is default. Flash Attention requires compatible GPU.' ) with gr.Column(): shared.gradio['image_compile'] = gr.Checkbox( value=shared.settings['image_compile'], label='Compile Model', info='Faster inference after first run. First run will be slow.' ) shared.gradio['image_cpu_offload'] = gr.Checkbox( value=shared.settings['image_cpu_offload'], label='CPU Offload', info='Enable for low VRAM GPUs. Slower but uses less memory.' ) with gr.Column(): shared.gradio['image_download_path'] = gr.Textbox( label="Download model", placeholder="Tongyi-MAI/Z-Image-Turbo", info="Enter HuggingFace path. Use : for branch, e.g. 
user/model:main" ) shared.gradio['image_download_btn'] = gr.Button("Download", variant='primary') shared.gradio['image_model_status'] = gr.Markdown(value="") def create_event_handlers(): # Dimension controls shared.gradio['image_aspect_ratio'].change( apply_aspect_ratio, gradio('image_aspect_ratio', 'image_width', 'image_height'), gradio('image_width', 'image_height'), show_progress=False ) shared.gradio['image_width'].release( update_height_from_width, gradio('image_width', 'image_aspect_ratio'), gradio('image_height'), show_progress=False ) shared.gradio['image_height'].release( update_width_from_height, gradio('image_height', 'image_aspect_ratio'), gradio('image_width'), show_progress=False ) shared.gradio['image_swap_btn'].click( swap_dimensions_and_update_ratio, gradio('image_width', 'image_height', 'image_aspect_ratio'), gradio('image_width', 'image_height', 'image_aspect_ratio'), show_progress=False ) # Generation shared.gradio['image_generate_btn'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then( generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn')) shared.gradio['image_prompt'].submit( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then( generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn')) shared.gradio['image_neg_prompt'].submit( ui.gather_interface_values, 
gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('image_stop_btn', 'image_generate_btn')).then( generate, gradio('interface_state'), gradio('image_output_gallery', 'image_progress'), show_progress=False).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('image_stop_btn', 'image_generate_btn')) # Stop button shared.gradio['image_stop_btn'].click( stop_everything_event, None, None, show_progress=False ) # Model management shared.gradio['image_refresh_models'].click( lambda: gr.update(choices=utils.get_available_image_models()), None, gradio('image_model_menu'), show_progress=False ) shared.gradio['image_load_model'].click( load_image_model_wrapper, gradio('image_model_menu', 'image_dtype', 'image_attn_backend', 'image_cpu_offload', 'image_compile', 'image_quant'), gradio('image_model_status'), show_progress=True ) shared.gradio['image_unload_model'].click( unload_image_model_wrapper, None, gradio('image_model_status'), show_progress=False ) shared.gradio['image_download_btn'].click( download_image_model_wrapper, gradio('image_download_path'), gradio('image_model_status', 'image_model_menu'), show_progress=True ) # Gallery pagination handlers shared.gradio['image_refresh_history'].click( refresh_gallery, gradio('image_current_page'), gradio('image_history_gallery', 'image_current_page', 'image_page_info'), show_progress=False ) shared.gradio['image_next_page'].click( next_page, gradio('image_current_page'), gradio('image_history_gallery', 'image_current_page', 'image_page_info'), show_progress=False ) shared.gradio['image_prev_page'].click( prev_page, gradio('image_current_page'), gradio('image_history_gallery', 'image_current_page', 'image_page_info'), show_progress=False ) shared.gradio['image_go_to_page'].click( go_to_page, gradio('image_page_input', 'image_current_page'), gradio('image_history_gallery', 'image_current_page', 'image_page_info'), 
show_progress=False ) # Image selection from gallery shared.gradio['image_history_gallery'].select( on_gallery_select, gradio('image_current_page'), gradio('image_selected_path', 'image_settings_display'), show_progress=False ) # Send to Generate shared.gradio['image_send_to_generate'].click( send_to_generate, gradio('image_selected_path'), gradio( 'image_prompt', 'image_neg_prompt', 'image_width', 'image_height', 'image_aspect_ratio', 'image_steps', 'image_seed', 'image_cfg_scale', 'image_gallery_status' ), js=f'() => {{{ui.switch_tabs_js}; switch_to_image_ai_generate()}}', show_progress=False ) shared.gradio['image_drop_upload'].change( read_dropped_image_metadata, gradio('image_drop_upload'), gradio('image_settings_display'), show_progress=False ) # LLM Variations visibility toggle shared.gradio['image_llm_variations'].change( lambda x: gr.update(visible=x), gradio('image_llm_variations'), gradio('image_llm_variations_prompt'), show_progress=False ) def generate_prompt_variation(state): """Generate a creative variation of the image prompt using the LLM.""" from modules.chat import generate_chat_prompt from modules.text_generation import generate_reply prompt = state['image_prompt'] # Check if LLM is loaded model_loaded, _ = check_model_loaded() if not model_loaded: logger.warning("No LLM loaded for prompt variation. Using original prompt.") return prompt # Get the custom variation prompt or use default variation_instruction = state.get('image_llm_variations_prompt', '') if not variation_instruction: variation_instruction = 'Write a variation of the image generation prompt above. Consider the intent of the user with that prompt and write something that will likely please them, with added details. Output only the new prompt. Do not add any explanations, prefixes, or additional text.' 
augmented_message = f"{prompt}\n\n=====\n\n{variation_instruction}" # Use minimal state for generation var_state = state.copy() var_state['history'] = {'internal': [], 'visible': [], 'metadata': {}} var_state['auto_max_new_tokens'] = True var_state['enable_thinking'] = False var_state['reasoning_effort'] = 'low' var_state['start_with'] = "" formatted_prompt = generate_chat_prompt(augmented_message, var_state) variation = "" for reply in generate_reply(formatted_prompt, var_state, stopping_strings=[], is_chat=True): variation = reply # Strip thinking blocks if present if "" in variation: variation = variation.rsplit("", 1)[1] elif "<|start|>assistant<|channel|>final<|message|>" in variation: variation = variation.rsplit("<|start|>assistant<|channel|>final<|message|>", 1)[1] elif "<|channel|>final<|message|>" in variation: variation = variation.rsplit("<|channel|>final<|message|>", 1)[1] elif "" in variation: variation = variation.rsplit("", 1)[1] variation = variation.strip() if len(variation) >= 2 and variation.startswith('"') and variation.endswith('"'): variation = variation[1:-1] if variation: logger.info("Prompt variation:") print(variation) return variation return prompt def progress_bar_html(progress=0, text=""): """Generate HTML for progress bar. Empty div when progress <= 0.""" if progress <= 0: return '
      ' return f'''
      {text}
      ''' def generate(state, save_images=True): """ Generate images using the loaded model. Automatically adjusts parameters based on pipeline type. """ import queue import threading import torch from modules.torch_utils import clear_torch_cache, get_device try: model_name = state['image_model_menu'] if not model_name or model_name == 'None': logger.error("No image model selected. Go to the Model tab and select a model.") yield [], progress_bar_html() return if shared.image_model is None: result = load_image_model( model_name, dtype=state['image_dtype'], attn_backend=state['image_attn_backend'], cpu_offload=state['image_cpu_offload'], compile_model=state['image_compile'], quant_method=state['image_quant'] ) if result is None: logger.error(f"Failed to load model `{model_name}`.") yield [], progress_bar_html() return shared.image_model_name = model_name seed = state['image_seed'] if seed == -1: seed = random.randint(0, 2**32 - 1) device = get_device() if device is None: device = "cpu" generator = torch.Generator(device) all_images = [] # Get pipeline type for parameter adjustment pipeline_type = getattr(shared, 'image_pipeline_type', None) if pipeline_type is None: pipeline_type = get_pipeline_type(shared.image_model) prompt = state['image_prompt'] shared.stop_everything = False batch_count = int(state['image_batch_count']) steps_per_batch = int(state['image_steps']) total_steps = steps_per_batch * batch_count # Queue for progress updates from callback progress_queue = queue.Queue() def interrupt_callback(pipe, step_index, timestep, callback_kwargs): if shared.stop_everything: pipe._interrupt = True progress_queue.put(step_index + 1) return callback_kwargs gen_kwargs = { "prompt": prompt, "negative_prompt": state['image_neg_prompt'], "height": int(state['image_height']), "width": int(state['image_width']), "num_inference_steps": steps_per_batch, "num_images_per_prompt": int(state['image_batch_size']), "generator": generator, "callback_on_step_end": 
interrupt_callback, } cfg_val = state.get('image_cfg_scale', 0.0) if pipeline_type == 'qwenimage': gen_kwargs["true_cfg_scale"] = cfg_val else: gen_kwargs["guidance_scale"] = cfg_val t0 = time.time() for batch_idx in range(batch_count): if shared.stop_everything: break generator.manual_seed(int(seed + batch_idx)) # Generate prompt variation if enabled if state['image_llm_variations']: gen_kwargs["prompt"] = generate_prompt_variation(state) # Run generation in thread so we can yield progress result_holder = [] error_holder = [] def run_batch(): try: # Apply magic suffix only at generation time for qwenimage clean_prompt = gen_kwargs["prompt"] if pipeline_type == 'qwenimage': magic_suffix = ", Ultra HD, 4K, cinematic composition" if magic_suffix.strip(", ") not in clean_prompt: gen_kwargs["prompt"] = clean_prompt + magic_suffix result_holder.extend(shared.image_model(**gen_kwargs).images) gen_kwargs["prompt"] = clean_prompt # restore except Exception as e: error_holder.append(e) thread = threading.Thread(target=run_batch) thread.start() # Yield progress updates while generation runs while thread.is_alive(): try: step = progress_queue.get(timeout=0.1) absolute_step = batch_idx * steps_per_batch + step pct = absolute_step / total_steps text = f"Batch {batch_idx + 1}/{batch_count} — Step {step}/{steps_per_batch}" yield all_images, progress_bar_html(pct, text) except queue.Empty: pass thread.join() if error_holder: raise error_holder[0] # Save this batch's images with the actual prompt and seed used if save_images: batch_seed = seed + batch_idx original_prompt = state['image_prompt'] state['image_prompt'] = gen_kwargs["prompt"] saved_paths = save_generated_images(result_holder, state, batch_seed) state['image_prompt'] = original_prompt # Use file paths so gallery serves actual PNGs with metadata all_images.extend(saved_paths) else: # Fallback to PIL objects if not saving all_images.extend(result_holder) yield all_images, progress_bar_html((batch_idx + 1) / batch_count, 
f"Batch {batch_idx + 1}/{batch_count} complete") t1 = time.time() total_images = batch_count * int(state['image_batch_size']) logger.info(f'Generated {total_images} {"image" if total_images == 1 else "images"} in {(t1 - t0):.2f} seconds ({total_steps / (t1 - t0):.2f} steps/s, seed {seed})') yield all_images, progress_bar_html() clear_torch_cache() except Exception as e: logger.error(f"Image generation failed: {e}") traceback.print_exc() yield [], progress_bar_html() clear_torch_cache() def load_image_model_wrapper(model_name, dtype, attn_backend, cpu_offload, compile_model, quant_method): if not model_name or model_name == 'None': yield "No model selected" return try: yield f"Loading `{model_name}`..." unload_image_model() result = load_image_model( model_name, dtype=dtype, attn_backend=attn_backend, cpu_offload=cpu_offload, compile_model=compile_model, quant_method=quant_method ) if result is not None: shared.image_model_name = model_name yield f"✓ Loaded **{model_name}** (quantization: {quant_method})" else: yield f"✗ Failed to load `{model_name}`" except Exception: yield f"Error:\n```\n{traceback.format_exc()}\n```" def unload_image_model_wrapper(): previous_name = shared.image_model_name unload_image_model() if previous_name != 'None': return f"Model: **{previous_name}** (unloaded)" return "No model loaded" def download_image_model_wrapper(model_path): from huggingface_hub import snapshot_download if not model_path: yield "No model specified", gr.update() return try: model_path = model_path.strip() if model_path.startswith('https://huggingface.co/'): model_path = model_path[len('https://huggingface.co/'):] elif model_path.startswith('huggingface.co/'): model_path = model_path[len('huggingface.co/'):] if ':' in model_path: model_id, branch = model_path.rsplit(':', 1) else: model_id, branch = model_path, 'main' folder_name = model_id.replace('/', '_') output_folder = Path(shared.args.image_model_dir) / folder_name yield f"Downloading `{model_id}` (branch: 
{branch})...", gr.update() snapshot_download( repo_id=model_id, revision=branch, local_dir=output_folder, local_dir_use_symlinks=False, ) new_choices = utils.get_available_image_models() yield f"✓ Downloaded to `{output_folder}`", gr.update(choices=new_choices, value=folder_name) except Exception: yield f"Error:\n```\n{traceback.format_exc()}\n```", gr.update() ================================================ FILE: modules/ui_model_menu.py ================================================ import importlib import math import queue import threading import traceback from functools import partial from pathlib import Path import gradio as gr from modules import loaders, shared, ui, utils from modules.logging_colors import logger from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model from modules.models_settings import ( apply_model_settings_to_state, get_model_metadata, save_instruction_template, save_model_settings, update_gpu_layers_and_vram, update_model_parameters ) from modules.utils import gradio def create_ui(): mu = shared.args.multi_user with gr.Tab("Model", elem_id="model-tab"): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['model_menu'] = gr.Dropdown(choices=utils.get_available_models(), value=lambda: shared.model_name, label='Model', elem_classes='slim-dropdown', interactive=not mu) ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu) shared.gradio['load_model'] = gr.Button("Load", elem_classes='refresh-button', interactive=not mu) shared.gradio['unload_model'] = gr.Button("Unload", elem_classes='refresh-button', interactive=not mu) shared.gradio['save_model_settings'] = gr.Button("Save settings", elem_classes='refresh-button', interactive=not mu) shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=loaders.loaders_and_params.keys() if not shared.args.portable else ['llama.cpp'], 
value=None) with gr.Blocks(): gr.Markdown("## Main options") with gr.Row(): with gr.Column(): shared.gradio['gpu_layers'] = gr.Slider(label="gpu-layers", minimum=-1, maximum=get_initial_gpu_layers_max(), step=1, value=shared.args.gpu_layers, info='Number of layers to offload to the GPU. -1 = auto.') shared.gradio['ctx_size'] = gr.Slider(label='ctx-size', minimum=0, maximum=1048576, step=1024, value=shared.args.ctx_size, info='Context length. 0 = auto for llama.cpp (requires gpu-layers=-1), 8192 for other loaders. Common values: 4096, 8192, 16384, 32768, 65536, 131072.') shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7') shared.gradio['attn_implementation'] = gr.Dropdown(label="attn-implementation", choices=['sdpa', 'eager', 'flash_attention_2'], value=shared.args.attn_implementation, info='Attention implementation.') shared.gradio['cache_type'] = gr.Dropdown(label="cache-type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).') shared.gradio['fit_target'] = gr.Textbox(label='fit-target', value=shared.args.fit_target, info='Target VRAM margin per device for auto GPU layers (MiB). Comma-separated list for multiple devices.') shared.gradio['tp_backend'] = gr.Dropdown(label="tp-backend", choices=['native', 'nccl'], value=shared.args.tp_backend, info='The backend for tensor parallelism.') with gr.Column(): shared.gradio['vram_info'] = gr.HTML(value=get_initial_vram_info()) shared.gradio['cpu_moe'] = gr.Checkbox(label="cpu-moe", value=shared.args.cpu_moe, info='Move the experts to the CPU. 
Saves VRAM on MoE models.') shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.') shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit) shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit) shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant, info='Used by load-in-4bit.') shared.gradio['enable_tp'] = gr.Checkbox(label="enable_tp", value=shared.args.enable_tp, info='Enable tensor parallelism (TP).') shared.gradio['tensorrt_llm_info'] = gr.Markdown( '* TensorRT-LLM has to be installed manually: `pip install tensorrt_llm==1.1.0 --extra-index-url https://pypi.nvidia.com`.\n\n' '* You can load either a pre-built TensorRT engine or a regular HF model. ' 'HF models will be compiled to a TensorRT engine automatically on each load (this can take a while).' ) # Multimodal with gr.Accordion("Multimodal (vision)", open=False, elem_classes='tgw-accordion') as shared.gradio['mmproj_accordion']: with gr.Row(): shared.gradio['mmproj'] = gr.Dropdown(label="mmproj file", choices=utils.get_available_mmproj(), value=lambda: shared.args.mmproj or 'None', elem_classes='slim-dropdown', info=f'Select a file that matches your model. Must be placed in {shared.user_data_dir}/mmproj/', interactive=not mu) ui.create_refresh_button(shared.gradio['mmproj'], lambda: None, lambda: {'choices': utils.get_available_mmproj()}, 'refresh-button', interactive=not mu) # Speculative decoding with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']: shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Maximum number of tokens to draft for speculative decoding. 
Recommended: 4 for draft model, 64 for n-gram.') gr.Markdown('#### Draft model') with gr.Row(): shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=['None'] + utils.get_available_models(), value=lambda: shared.args.model_draft, elem_classes='slim-dropdown', info='Draft model. Must share the same vocabulary as the main model.', interactive=not mu) ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': ['None'] + utils.get_available_models()}, 'refresh-button', interactive=not mu) shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.') shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model. Example: CUDA0,CUDA1') shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.') shared.gradio['ngram_header'] = gr.Markdown('#### N-gram (draftless)') shared.gradio['spec_type'] = gr.Dropdown(label="spec-type", choices=['none', 'ngram-mod', 'ngram-simple', 'ngram-map-k', 'ngram-map-k4v', 'ngram-cache'], value=shared.args.spec_type, info='Draftless speculative decoding type. 
Recommended: ngram-mod.') shared.gradio['spec_ngram_size_n'] = gr.Number(label="spec-ngram-size-n", precision=0, step=1, value=shared.args.spec_ngram_size_n, info='N-gram lookup size for speculative decoding.', visible=shared.args.spec_type != 'none') shared.gradio['spec_ngram_size_m'] = gr.Number(label="spec-ngram-size-m", precision=0, step=1, value=shared.args.spec_ngram_size_m, info='Draft n-gram size for speculative decoding.', visible=shared.args.spec_type != 'none') shared.gradio['spec_ngram_min_hits'] = gr.Number(label="spec-ngram-min-hits", precision=0, step=1, value=shared.args.spec_ngram_min_hits, info='Minimum n-gram hits for ngram-map speculative decoding.', visible=shared.args.spec_type != 'none') gr.Markdown("## Other options") with gr.Accordion("See more options", open=False, elem_classes='tgw-accordion'): with gr.Row(): with gr.Column(): shared.gradio['parallel'] = gr.Slider(label="parallel", minimum=1, step=1, maximum=64, value=shared.args.parallel, info='Number of parallel request slots for the API. The context size is divided equally among slots. For example, to have 4 slots with 8192 context each, set ctx_size to 32768.') shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads) shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch) shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size) shared.gradio['ubatch_size'] = gr.Slider(label="ubatch_size", minimum=1, maximum=4096, step=1, value=shared.args.ubatch_size) shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40') shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1,flag2,flag3=value3". 
Example: "override-tensor=exps=CPU"', value=shared.args.extra_flags) shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory) shared.gradio['compute_dtype'] = gr.Dropdown(label="compute_dtype", choices=["bfloat16", "float16", "float32"], value=shared.args.compute_dtype, info='Used by load-in-4bit.') shared.gradio['quant_type'] = gr.Dropdown(label="quant_type", choices=["nf4", "fp4"], value=shared.args.quant_type, info='Used by load-in-4bit.') with gr.Column(): shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='Use PyTorch in CPU mode.') shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk) shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.') shared.gradio['no_kv_offload'] = gr.Checkbox(label="no_kv_offload", value=shared.args.no_kv_offload, info='Do not offload the K, Q, V to the GPU. 
This saves VRAM but reduces the performance.') shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap) shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock) shared.gradio['numa'] = gr.Checkbox(label="numa", value=shared.args.numa, info='NUMA support can help on some systems with non-uniform memory access.') shared.gradio['bf16'] = gr.Checkbox(label="bf16", value=shared.args.bf16) shared.gradio['cfg_cache'] = gr.Checkbox(label="cfg-cache", value=shared.args.cfg_cache, info='Necessary to use CFG with this loader.') shared.gradio['no_use_fast'] = gr.Checkbox(label="no_use_fast", value=shared.args.no_use_fast, info='Set use_fast=False while loading the tokenizer.') if not shared.args.portable: with gr.Row(): shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=utils.get_available_loras(), value=shared.lora_names, label='LoRA(s)', elem_classes='slim-dropdown', interactive=not mu) ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu) shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu) with gr.Column(): with gr.Tab("Download"): shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. 
To download a single file, enter its name in the second box.", interactive=not mu) shared.gradio['download_specific_file'] = gr.Textbox(placeholder="File name (for GGUF models)", show_label=False, max_lines=1, interactive=not mu) with gr.Row(): shared.gradio['download_model_button'] = gr.Button("Download", variant='primary', interactive=not mu) shared.gradio['get_file_list'] = gr.Button("Get file list", interactive=not mu) with gr.Tab("Customize instruction template"): with gr.Row(): shared.gradio['customized_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), value='None', label='Select the desired instruction template', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['customized_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu) shared.gradio['customized_template_submit'] = gr.Button("Submit", variant="primary", interactive=not mu) gr.Markdown("This allows you to set a customized template for the model currently selected in the \"Model loader\" menu. 
Whenever the model gets loaded, this template will be used in place of the template specified in the model's medatada, which sometimes is wrong.") with gr.Row(): shared.gradio['model_status'] = gr.Markdown('No model is loaded' if shared.model_name == 'None' else 'Ready') def create_event_handlers(): mu = shared.args.multi_user if mu: return shared.gradio['loader'].change(loaders.make_loader_params_visible, gradio('loader'), gradio(loaders.get_all_params()), show_progress=False) # In this event handler, the interface state is read and updated # with the model defaults (if any), and then the model is loaded shared.gradio['model_menu'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( handle_load_model_event_initial, gradio('model_menu', 'interface_state'), gradio(ui.list_interface_input_elements()) + gradio('interface_state') + gradio('vram_info'), show_progress=False).then( partial(load_model_wrapper, autoload=False), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=True).success( handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False) shared.gradio['load_model'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( update_model_parameters, gradio('interface_state'), None).then( partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=True).success( handle_load_model_event_final, gradio('truncation_length', 'loader', 'interface_state'), gradio('truncation_length', 'filter_by_loader'), show_progress=False) shared.gradio['unload_model'].click(handle_unload_model_click, None, gradio('model_status'), show_progress=False).then( update_gpu_layers_and_vram, gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False) 
shared.gradio['save_model_settings'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False) # For ctx_size and cache_type - update VRAM display for param in ['ctx_size', 'cache_type']: shared.gradio[param].change( update_gpu_layers_and_vram, gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False) # For manual gpu_layers changes - only update VRAM shared.gradio['gpu_layers'].change( update_gpu_layers_and_vram, gradio('loader', 'model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False) if not shared.args.portable: shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False) shared.gradio['spec_type'].change( lambda x: [gr.update(visible=x != 'none')] * 3, gradio('spec_type'), gradio('spec_ngram_size_n', 'spec_ngram_size_m', 'spec_ngram_min_hits'), show_progress=False ) shared.gradio['download_model_button'].click(download_model_wrapper, gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['get_file_list'].click(partial(download_model_wrapper, return_links=True), gradio('custom_model_menu', 'download_specific_file'), gradio('model_status'), show_progress=True) shared.gradio['customized_template_submit'].click(save_instruction_template, gradio('model_menu', 'customized_template'), gradio('model_status'), show_progress=True) def load_model_wrapper(selected_model, loader, autoload=False): try: settings = get_model_metadata(selected_model) except FileNotFoundError: exc = traceback.format_exc() yield exc.replace('\n', '\n\n') return if not autoload: yield "### {}\n\n- Settings updated: Click \"Load\" to load the model\n- Max sequence length: {}".format(selected_model, settings['truncation_length_info']) return 
if selected_model == 'None': yield "No model selected" else: try: yield f"Loading `{selected_model}`..." unload_model() if selected_model != '': shared.model, shared.tokenizer = load_model(selected_model, loader) if shared.model is not None: yield f"Successfully loaded `{selected_model}`." else: yield f"Failed to load `{selected_model}`." except Exception: exc = traceback.format_exc() logger.error('Failed to load the model.') print(exc) yield exc.replace('\n', '\n\n') def load_lora_wrapper(selected_loras): yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras))) add_lora_to_model(selected_loras) yield ("Successfuly applied the LoRAs") def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False): downloader_module = importlib.import_module("download-model") downloader = downloader_module.ModelDownloader() update_queue = queue.Queue() try: # Handle direct GGUF URLs if repo_id.startswith("https://") and ("huggingface.co" in repo_id) and (repo_id.endswith(".gguf") or repo_id.endswith(".gguf?download=true")): try: path = repo_id.split("huggingface.co/")[1] parts = path.split("/") if len(parts) >= 2: extracted_repo_id = f"{parts[0]}/{parts[1]}" filename = repo_id.split("/")[-1].replace("?download=true", "") repo_id = extracted_repo_id specific_file = filename except Exception as e: yield f"Error parsing GGUF URL: {e}" progress(0.0) return if not repo_id: yield "Please enter a model path." progress(0.0) return repo_id = repo_id.strip() specific_file = specific_file.strip() progress(0.0, "Preparing download...") model, branch = downloader.sanitize_model_and_branch_names(repo_id, None) yield "Getting download links from Hugging Face..." links, sha256, is_lora, is_llamacpp, file_sizes = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file) if not links: yield "No files found to download for the given model/criteria." 
progress(0.0) return # Check for multiple GGUF files gguf_files = [link for link in links if link.lower().endswith('.gguf')] if len(gguf_files) > 1 and not specific_file: # Sort by size in ascending order gguf_data = [] for i, link in enumerate(links): if link.lower().endswith('.gguf'): file_size = file_sizes[i] gguf_data.append((file_size, link)) gguf_data.sort(key=lambda x: x[0]) output = "Multiple GGUF files found. Please copy one of the following filenames to the 'File name' field above:\n\n```\n" for file_size, link in gguf_data: size_str = format_file_size(file_size) output += f"{size_str} - {Path(link).name}\n" output += "```" yield output return if return_links: # Sort by size in ascending order file_data = list(zip(file_sizes, links)) file_data.sort(key=lambda x: x[0]) output = "```\n" for file_size, link in file_data: size_str = format_file_size(file_size) output += f"{size_str} - {Path(link).name}\n" output += "```" yield output return yield "Determining output folder..." output_folder = downloader.get_output_folder( model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None ) if output_folder == shared.user_data_dir / "models": output_folder = Path(shared.args.model_dir) elif output_folder == shared.user_data_dir / "loras": output_folder = Path(shared.args.lora_dir) if check: yield "Checking previously downloaded files..." progress(0.5, "Verifying files...") downloader.check_model_files(model, branch, links, sha256, output_folder) progress(1.0, "Verification complete.") yield "File check complete." 
return yield "" progress(0.0, "Download starting...") def downloader_thread_target(): try: downloader.download_model_files( model, branch, links, sha256, output_folder, progress_queue=update_queue, threads=4, is_llamacpp=is_llamacpp, specific_file=specific_file ) update_queue.put(("COMPLETED", f"Model successfully saved to `{output_folder}/`.")) except Exception as e: tb_str = traceback.format_exc().replace('\n', '\n\n') update_queue.put(("ERROR", tb_str)) download_thread = threading.Thread(target=downloader_thread_target) download_thread.start() while True: try: message = update_queue.get(timeout=0.2) if not isinstance(message, tuple) or len(message) != 2: continue msg_identifier, data = message if msg_identifier == "COMPLETED": progress(1.0, "Download complete!") yield data break elif msg_identifier == "ERROR": progress(0.0, "Error occurred") yield data break elif isinstance(msg_identifier, float): progress_value = msg_identifier description_str = data progress(progress_value, f"Downloading: {description_str}") except queue.Empty: if not download_thread.is_alive(): yield "Download process finished." break download_thread.join() except Exception as e: progress(0.0) tb_str = traceback.format_exc().replace('\n', '\n\n') yield tb_str def update_truncation_length(current_length, state): if 'loader' in state: if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp': if state['ctx_size'] > 0: return state['ctx_size'] # ctx_size == 0 means auto: use the actual value from the server return shared.settings['truncation_length'] return current_length def get_initial_vram_info(): if shared.model_name != 'None' and shared.args.loader == 'llama.cpp': return update_gpu_layers_and_vram( shared.args.loader, shared.model_name, shared.args.gpu_layers, shared.args.ctx_size, shared.args.cache_type, ) return "
      Estimated VRAM to load the model:
      " def get_initial_gpu_layers_max(): if shared.model_name != 'None' and shared.args.loader == 'llama.cpp': model_settings = get_model_metadata(shared.model_name) return model_settings.get('max_gpu_layers', 256) return 256 def handle_load_model_event_initial(model, state): state = apply_model_settings_to_state(model, state) output = ui.apply_interface_values(state) update_model_parameters(state) # This updates the command-line flags vram_info = state.get('vram_info', "
      Estimated VRAM to load the model:
      ") return output + [state] + [vram_info] def handle_load_model_event_final(truncation_length, loader, state): truncation_length = update_truncation_length(truncation_length, state) return [truncation_length, loader] def handle_unload_model_click(): unload_model() return "Model unloaded" def format_file_size(size_bytes): """Convert bytes to human readable format with 2 decimal places for GB and above""" if size_bytes == 0: return "0 B" size_names = ["B", "KB", "MB", "GB", "TB"] i = int(math.floor(math.log(size_bytes, 1024))) p = math.pow(1024, i) s = size_bytes / p if i >= 3: # GB or TB return f"{s:.2f} {size_names[i]}" else: return f"{s:.1f} {size_names[i]}" ================================================ FILE: modules/ui_notebook.py ================================================ import threading import time from pathlib import Path import gradio as gr from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( generate_reply_wrapper, get_token_ids, stop_everything_event ) from modules.utils import gradio _notebook_file_lock = threading.Lock() _notebook_auto_save_timer = None _last_notebook_text = None _last_notebook_prompt = None inputs = ('textbox-notebook', 'interface_state') outputs = ('textbox-notebook', 'html-notebook') def create_ui(): mu = shared.args.multi_user with gr.Row(visible=not shared.settings['show_two_notebook_columns']) as shared.gradio['notebook-tab']: shared.gradio['last_input-notebook'] = gr.State('') with gr.Row(): with gr.Column(scale=4): with gr.Tab('Raw'): with gr.Row(): shared.gradio['textbox-notebook'] = gr.Textbox(label="", value="", lines=27, elem_id='textbox-notebook', elem_classes=['textbox', 'add_scrollbar']) shared.gradio['token-counter-notebook'] = gr.HTML(value="0", elem_id="notebook-token-counter") with gr.Tab('Markdown'): shared.gradio['markdown_render-notebook'] = gr.Button('Render') shared.gradio['markdown-notebook'] = gr.Markdown() with 
gr.Tab('HTML'): shared.gradio['html-notebook'] = gr.HTML() with gr.Tab('Logits'): with gr.Row(): with gr.Column(scale=10): shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities') with gr.Column(scale=1): shared.gradio['use_samplers-notebook'] = gr.Checkbox(label='Use samplers', value=True, elem_classes=['no-background']) with gr.Row(): shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) shared.gradio['logits-notebook-previous'] = gr.Textbox(lines=23, label='Previous output', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) with gr.Tab('Tokens'): shared.gradio['get_tokens-notebook'] = gr.Button('Get token IDs for the input') shared.gradio['tokens-notebook'] = gr.Textbox(lines=23, label='Tokens', elem_classes=['textbox_logits_notebook', 'add_scrollbar', 'monospace']) with gr.Row(): shared.gradio['Undo'] = gr.Button('Undo', elem_classes='small-button') shared.gradio['Regenerate-notebook'] = gr.Button('Regenerate', elem_classes='small-button') shared.gradio['Stop-notebook'] = gr.Button('Stop', visible=False, elem_classes='small-button', elem_id='stop') shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button') with gr.Column(scale=1): gr.HTML('
      ') with gr.Row(): shared.gradio['prompt_menu-notebook'] = gr.Dropdown(choices=utils.get_available_prompts(), value=shared.settings['prompt-notebook'], label='Prompt', elem_classes='slim-dropdown') with gr.Row(): ui.create_refresh_button(shared.gradio['prompt_menu-notebook'], lambda: None, lambda: {'choices': utils.get_available_prompts()}, ['refresh-button'], interactive=not mu) shared.gradio['new_prompt-notebook'] = gr.Button('New', elem_classes=['refresh-button'], interactive=not mu) shared.gradio['rename_prompt-notebook'] = gr.Button('Rename', elem_classes=['refresh-button'], interactive=not mu) shared.gradio['delete_prompt-notebook'] = gr.Button('🗑️', elem_classes=['refresh-button'], interactive=not mu) shared.gradio['delete_prompt-confirm-notebook'] = gr.Button('Confirm', variant='stop', elem_classes=['refresh-button'], visible=False) shared.gradio['delete_prompt-cancel-notebook'] = gr.Button('Cancel', elem_classes=['refresh-button'], visible=False) with gr.Row(visible=False) as shared.gradio['rename-row-notebook']: shared.gradio['rename_prompt_to-notebook'] = gr.Textbox(label="New name", elem_classes=['no-background']) shared.gradio['rename_prompt-cancel-notebook'] = gr.Button('Cancel', elem_classes=['refresh-button']) shared.gradio['rename_prompt-confirm-notebook'] = gr.Button('Confirm', elem_classes=['refresh-button'], variant='primary') def create_event_handlers(): shared.gradio['Generate-notebook'].click( lambda x: x, gradio('textbox-notebook'), gradio('last_input-notebook')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-notebook', 'Generate-notebook')).then( generate_and_save_wrapper_notebook, gradio('textbox-notebook', 'interface_state', 'prompt_menu-notebook'), gradio(outputs), show_progress=False).then( lambda state, text: state.update({'textbox-notebook': text}), gradio('interface_state', 'textbox-notebook'), 
None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-notebook', 'Generate-notebook')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['textbox-notebook'].submit( lambda x: x, gradio('textbox-notebook'), gradio('last_input-notebook')).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-notebook', 'Generate-notebook')).then( generate_and_save_wrapper_notebook, gradio('textbox-notebook', 'interface_state', 'prompt_menu-notebook'), gradio(outputs), show_progress=False).then( lambda state, text: state.update({'textbox-notebook': text}), gradio('interface_state', 'textbox-notebook'), None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-notebook', 'Generate-notebook')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Regenerate-notebook'].click( lambda x: x, gradio('last_input-notebook'), gradio('textbox-notebook'), show_progress=False).then( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('Stop-notebook', 'Generate-notebook')).then( generate_and_save_wrapper_notebook, gradio('textbox-notebook', 'interface_state', 'prompt_menu-notebook'), gradio(outputs), show_progress=False).then( lambda state, text: state.update({'textbox-notebook': text}), gradio('interface_state', 'textbox-notebook'), None).then( lambda: [gr.update(visible=False), gr.update(visible=True)], None, gradio('Stop-notebook', 'Generate-notebook')).then( None, None, None, js=f'() => {{{ui.audio_notification_js}}}') shared.gradio['Undo'].click( lambda x: x, gradio('last_input-notebook'), gradio('textbox-notebook'), show_progress=False).then( lambda state, text: state.update({'textbox-notebook': text}), gradio('interface_state', 
'textbox-notebook'), None) shared.gradio['markdown_render-notebook'].click(lambda x: x, gradio('textbox-notebook'), gradio('markdown-notebook'), queue=False) shared.gradio['Stop-notebook'].click(stop_everything_event, None, None, queue=False) shared.gradio['prompt_menu-notebook'].change(load_prompt, gradio('prompt_menu-notebook'), gradio('textbox-notebook'), show_progress=False) shared.gradio['new_prompt-notebook'].click(handle_new_prompt, None, gradio('prompt_menu-notebook'), show_progress=False) shared.gradio['delete_prompt-notebook'].click( lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)], None, gradio('delete_prompt-notebook', 'delete_prompt-cancel-notebook', 'delete_prompt-confirm-notebook'), show_progress=False) shared.gradio['delete_prompt-cancel-notebook'].click( lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)], None, gradio('delete_prompt-notebook', 'delete_prompt-cancel-notebook', 'delete_prompt-confirm-notebook'), show_progress=False) shared.gradio['delete_prompt-confirm-notebook'].click( handle_delete_prompt_confirm_notebook, gradio('prompt_menu-notebook'), gradio('prompt_menu-notebook', 'delete_prompt-notebook', 'delete_prompt-cancel-notebook', 'delete_prompt-confirm-notebook'), show_progress=False) shared.gradio['rename_prompt-notebook'].click( handle_rename_prompt_click_notebook, gradio('prompt_menu-notebook'), gradio('rename_prompt_to-notebook', 'rename_prompt-notebook', 'rename-row-notebook'), show_progress=False) shared.gradio['rename_prompt-cancel-notebook'].click( lambda: [gr.update(visible=True), gr.update(visible=False)], None, gradio('rename_prompt-notebook', 'rename-row-notebook'), show_progress=False) shared.gradio['rename_prompt-confirm-notebook'].click( handle_rename_prompt_confirm_notebook, gradio('rename_prompt_to-notebook', 'prompt_menu-notebook'), gradio('prompt_menu-notebook', 'rename_prompt-notebook', 'rename-row-notebook'), show_progress=False) 
shared.gradio['textbox-notebook'].input(lambda x: f"{count_tokens(x)}", gradio('textbox-notebook'), gradio('token-counter-notebook'), show_progress=False) shared.gradio['textbox-notebook'].change( store_notebook_state_and_debounce, gradio('textbox-notebook', 'prompt_menu-notebook'), None, show_progress=False ) shared.gradio['get_logits-notebook'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( logits.get_next_logits, gradio('textbox-notebook', 'interface_state', 'use_samplers-notebook', 'logits-notebook'), gradio('logits-notebook', 'logits-notebook-previous'), show_progress=False) shared.gradio['get_tokens-notebook'].click(get_token_ids, gradio('textbox-notebook'), gradio('tokens-notebook'), show_progress=False) def generate_and_save_wrapper_notebook(textbox_content, interface_state, prompt_name): """Generate reply and automatically save the result for notebook mode with periodic saves""" last_save_time = time.monotonic() save_interval = 8 output = textbox_content # Initial autosave safe_autosave_prompt(output, prompt_name) for i, (output, html_output) in enumerate(generate_reply_wrapper(textbox_content, interface_state)): yield output, html_output current_time = time.monotonic() # Save on first iteration or if save_interval seconds have passed if i == 0 or (current_time - last_save_time) >= save_interval: safe_autosave_prompt(output, prompt_name) last_save_time = current_time # Final autosave safe_autosave_prompt(output, prompt_name) def handle_new_prompt(): new_name = utils.current_time() # Create the new prompt file prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text("In this story,", encoding='utf-8') return gr.update(choices=utils.get_available_prompts(), value=new_name) def handle_delete_prompt_confirm_notebook(prompt_name): available_prompts = utils.get_available_prompts() current_index = 
available_prompts.index(prompt_name) if prompt_name in available_prompts else 0 (shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt").unlink(missing_ok=True) available_prompts = utils.get_available_prompts() if available_prompts: new_value = available_prompts[min(current_index, len(available_prompts) - 1)] else: new_value = utils.current_time() (shared.user_data_dir / "logs" / "notebook").mkdir(parents=True, exist_ok=True) (shared.user_data_dir / "logs" / "notebook" / f"{new_value}.txt").write_text("In this story,") available_prompts = [new_value] return [ gr.update(choices=available_prompts, value=new_value), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False) ] def handle_rename_prompt_click_notebook(current_name): return [ gr.update(value=current_name), gr.update(visible=False), gr.update(visible=True) ] def handle_rename_prompt_confirm_notebook(new_name, current_name): old_path = shared.user_data_dir / "logs" / "notebook" / f"{current_name}.txt" new_path = shared.user_data_dir / "logs" / "notebook" / f"{new_name}.txt" if old_path.exists() and not new_path.exists(): old_path.rename(new_path) available_prompts = utils.get_available_prompts() return [ gr.update(choices=available_prompts, value=new_name), gr.update(visible=True), gr.update(visible=False) ] def autosave_prompt(text, prompt_name): """Automatically save the text to the selected prompt file""" if prompt_name and text.strip(): prompt_path = shared.user_data_dir / "logs" / "notebook" / f"{prompt_name}.txt" prompt_path.parent.mkdir(parents=True, exist_ok=True) prompt_path.write_text(text, encoding='utf-8') def safe_autosave_prompt(content, prompt_name): """Thread-safe wrapper for autosave_prompt to prevent file corruption""" with _notebook_file_lock: autosave_prompt(content, prompt_name) def store_notebook_state_and_debounce(text, prompt_name): """Store current notebook state and trigger debounced save""" global _notebook_auto_save_timer, _last_notebook_text, 
_last_notebook_prompt if shared.args.multi_user: return _last_notebook_text = text _last_notebook_prompt = prompt_name if _notebook_auto_save_timer is not None: _notebook_auto_save_timer.cancel() _notebook_auto_save_timer = threading.Timer(1.0, _perform_notebook_debounced_save) _notebook_auto_save_timer.start() def _perform_notebook_debounced_save(): """Actually perform the notebook save using the stored state""" try: if _last_notebook_text is not None and _last_notebook_prompt is not None: safe_autosave_prompt(_last_notebook_text, _last_notebook_prompt) except Exception as e: print(f"Notebook auto-save failed: {e}") ================================================ FILE: modules/ui_parameters.py ================================================ from pathlib import Path import gradio as gr from modules import loaders, presets, shared, ui, ui_chat, utils from modules.utils import gradio def create_ui(): mu = shared.args.multi_user with gr.Tab("Parameters", elem_id="parameters"): with gr.Tab("Generation"): with gr.Row(): with gr.Column(): with gr.Row(): shared.gradio['preset_menu'] = gr.Dropdown(choices=utils.get_available_presets(), value=shared.settings['preset'], label='Preset', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': utils.get_available_presets()}, 'refresh-button', interactive=not mu) shared.gradio['save_preset'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) shared.gradio['delete_preset'] = gr.Button('🗑️', elem_classes='refresh-button', interactive=not mu) shared.gradio['reset_preset'] = gr.Button('Restore preset', elem_classes='refresh-button', interactive=True) shared.gradio['neutralize_samplers'] = gr.Button('Neutralize samplers', elem_classes='refresh-button', interactive=True) with gr.Column(): shared.gradio['filter_by_loader'] = gr.Dropdown(label="Filter by loader", choices=["All"] + list(loaders.loaders_and_params.keys()) if not shared.args.portable else 
['llama.cpp'], value="All", elem_classes='slim-dropdown') with gr.Row(): with gr.Column(): with gr.Row(): with gr.Column(): gr.Markdown('## Curve shape') shared.gradio['temperature'] = gr.Slider(0.01, 5, value=shared.settings['temperature'], step=0.01, label='temperature') shared.gradio['dynatemp_low'] = gr.Slider(0.01, 5, value=shared.settings['dynatemp_low'], step=0.01, label='dynatemp_low', visible=shared.settings['dynamic_temperature']) shared.gradio['dynatemp_high'] = gr.Slider(0.01, 5, value=shared.settings['dynatemp_high'], step=0.01, label='dynatemp_high', visible=shared.settings['dynamic_temperature']) shared.gradio['dynatemp_exponent'] = gr.Slider(0.01, 5, value=shared.settings['dynatemp_exponent'], step=0.01, label='dynatemp_exponent', visible=shared.settings['dynamic_temperature']) shared.gradio['smoothing_factor'] = gr.Slider(0.0, 10.0, value=shared.settings['smoothing_factor'], step=0.01, label='smoothing_factor', info='Activates Quadratic Sampling.') shared.gradio['smoothing_curve'] = gr.Slider(1.0, 10.0, value=shared.settings['smoothing_curve'], step=0.01, label='smoothing_curve', info='Adjusts the dropoff curve of Quadratic Sampling.') shared.gradio['dynamic_temperature'] = gr.Checkbox(value=shared.settings['dynamic_temperature'], label='dynamic_temperature') gr.Markdown('## Curve cutoff') shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=shared.settings['top_p'], step=0.01, label='top_p') shared.gradio['top_k'] = gr.Slider(0, 200, value=shared.settings['top_k'], step=1, label='top_k') shared.gradio['min_p'] = gr.Slider(0.0, 1.0, value=shared.settings['min_p'], step=0.01, label='min_p') shared.gradio['top_n_sigma'] = gr.Slider(0.0, 5.0, value=shared.settings['top_n_sigma'], step=0.01, label='top_n_sigma') shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=shared.settings['typical_p'], step=0.01, label='typical_p') shared.gradio['xtc_threshold'] = gr.Slider(0, 0.5, value=shared.settings['xtc_threshold'], step=0.01, label='xtc_threshold', 
info='If 2 or more tokens have probability above this threshold, consider removing all but the last one.') shared.gradio['xtc_probability'] = gr.Slider(0, 1, value=shared.settings['xtc_probability'], step=0.01, label='xtc_probability', info='Probability that the removal will actually happen. 0 disables the sampler. 1 makes it always happen.') shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=shared.settings['epsilon_cutoff'], step=0.01, label='epsilon_cutoff') shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=shared.settings['eta_cutoff'], step=0.01, label='eta_cutoff') shared.gradio['tfs'] = gr.Slider(0.0, 1.0, value=shared.settings['tfs'], step=0.01, label='tfs') shared.gradio['top_a'] = gr.Slider(0.0, 1.0, value=shared.settings['top_a'], step=0.01, label='top_a') gr.Markdown('## Repetition suppression') shared.gradio['dry_multiplier'] = gr.Slider(0, 5, value=shared.settings['dry_multiplier'], step=0.01, label='dry_multiplier', info='Set to greater than 0 to enable DRY. Recommended value: 0.8.') shared.gradio['dry_allowed_length'] = gr.Slider(1, 20, value=shared.settings['dry_allowed_length'], step=1, label='dry_allowed_length', info='Longest sequence that can be repeated without being penalized.') shared.gradio['dry_base'] = gr.Slider(1, 4, value=shared.settings['dry_base'], step=0.01, label='dry_base', info='Controls how fast the penalty grows with increasing sequence length.') shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=shared.settings['repetition_penalty'], step=0.01, label='repetition_penalty') shared.gradio['frequency_penalty'] = gr.Slider(0, 2, value=shared.settings['frequency_penalty'], step=0.05, label='frequency_penalty') shared.gradio['presence_penalty'] = gr.Slider(0, 2, value=shared.settings['presence_penalty'], step=0.05, label='presence_penalty') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=shared.settings['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty') 
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=shared.settings['no_repeat_ngram_size'], label='no_repeat_ngram_size') shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=shared.settings['repetition_penalty_range'], label='repetition_penalty_range') with gr.Column(): gr.Markdown('## Alternative sampling methods') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=shared.settings['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.') shared.gradio['guidance_scale'] = gr.Slider(-0.5, 2.5, step=0.05, value=shared.settings['guidance_scale'], label='guidance_scale', info='For CFG. 1.5 is a good value.') shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=shared.settings['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.') shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=shared.settings['mirostat_tau'], label='mirostat_tau') shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=shared.settings['mirostat_eta'], label='mirostat_eta') shared.gradio['adaptive_target'] = gr.Slider(0.0, 1.0, value=shared.settings['adaptive_target'], step=0.01, label='adaptive_target', info='Target probability for adaptive-p sampling. Tokens near this probability are favored. 0 disables.') shared.gradio['adaptive_decay'] = gr.Slider(0.0, 0.99, value=shared.settings['adaptive_decay'], step=0.01, label='adaptive_decay', info='EMA decay rate for adaptive-p. Controls history window (~1/(1-decay) tokens). 
Default: 0.9.') gr.Markdown('## Other options') shared.gradio['do_sample'] = gr.Checkbox(value=shared.settings['do_sample'], label='do_sample') shared.gradio['temperature_last'] = gr.Checkbox(value=shared.settings['temperature_last'], label='temperature_last', info='Moves temperature/dynamic temperature/quadratic sampling to the end of the sampler stack, ignoring their positions in "Sampler priority".') shared.gradio['sampler_priority'] = gr.DragDrop(value=shared.settings['sampler_priority'], label='Sampler priority', info='Parameter names separated by new lines or commas.', elem_classes=['add_scrollbar']) shared.gradio['dry_sequence_breakers'] = gr.Textbox(value=shared.settings['dry_sequence_breakers'], label='dry_sequence_breakers', info='Tokens across which sequence matching is not continued. Specified as a comma-separated list of quoted strings.') with gr.Column(): with gr.Row(): with gr.Column(): with gr.Blocks(): shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], value=shared.settings['max_new_tokens'], step=1, label='max_new_tokens', info='⚠️ Setting this too high can cause prompt truncation.') shared.gradio['prompt_lookup_num_tokens'] = gr.Slider(value=shared.settings['prompt_lookup_num_tokens'], minimum=0, maximum=10, step=1, label='prompt_lookup_num_tokens', info='Activates Prompt Lookup Decoding.') shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum tokens/second', info='To make text readable in real time.') shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.') shared.gradio['ban_eos_token'] = gr.Checkbox(value=shared.settings['ban_eos_token'], label='Ban the eos_token', info='Forces the model to never end the generation prematurely.') 
shared.gradio['add_bos_token'] = gr.Checkbox(value=shared.settings['add_bos_token'], label='Add the bos_token to the beginning of prompts', info='Only applies to text completion (notebook). In chat mode, templates control BOS tokens.') shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['stream'] = gr.Checkbox(value=shared.settings['stream'], label='Activate text streaming') shared.gradio['static_cache'] = gr.Checkbox(value=shared.settings['static_cache'], label='Static KV cache', info='Use a static cache for improved performance.') with gr.Column(): shared.gradio['truncation_length'] = gr.Number(precision=0, step=256, value=get_truncation_length(), label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length.') shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') shared.gradio['custom_system_message'] = gr.Textbox(value=shared.settings['custom_system_message'], lines=2, label='Custom system message', info='If not empty, will be used instead of the default one.', elem_classes=['add_scrollbar']) shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=2, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"') shared.gradio['custom_token_bans'] = gr.Textbox(value=shared.settings['custom_token_bans'] or None, label='Token bans', info='Token IDs to ban, separated by commas. The IDs can be found in the Default or Notebook tab.') shared.gradio['negative_prompt'] = gr.Textbox(value=shared.settings['negative_prompt'], label='Negative prompt', info='For CFG. 
Only used when guidance_scale is different than 1.', lines=3, elem_classes=['add_scrollbar']) with gr.Row() as shared.gradio['grammar_file_row']: shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Load grammar from file (.gbnf)', elem_classes='slim-dropdown') ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button', interactive=not mu) shared.gradio['save_grammar'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu) shared.gradio['delete_grammar'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu) shared.gradio['grammar_string'] = gr.Textbox(value=shared.settings['grammar_string'], label='Grammar', lines=16, elem_classes=['add_scrollbar', 'monospace']) ui_chat.create_chat_settings_ui() def create_event_handlers(): shared.gradio['filter_by_loader'].change(loaders.blacklist_samplers, gradio('filter_by_loader', 'dynamic_temperature'), gradio(loaders.list_all_samplers()), show_progress=False) shared.gradio['preset_menu'].change( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( presets.load_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()), show_progress=False) shared.gradio['reset_preset'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( presets.reset_preset_for_ui, gradio('preset_menu', 'interface_state'), gradio('interface_state') + gradio(presets.presets_params()), show_progress=False) shared.gradio['neutralize_samplers'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( presets.neutralize_samplers_for_ui, gradio('interface_state'), gradio('interface_state') + gradio(presets.presets_params()), show_progress=False) shared.gradio['grammar_file'].change(load_grammar, gradio('grammar_file'), 
gradio('grammar_string'), show_progress=False) shared.gradio['dynamic_temperature'].change(lambda x: [gr.update(visible=x)] * 3, gradio('dynamic_temperature'), gradio('dynatemp_low', 'dynatemp_high', 'dynatemp_exponent'), show_progress=False) def get_truncation_length(): if shared.args.ctx_size > 0 and ('ctx_size' in shared.provided_arguments or shared.args.ctx_size != shared.args_defaults.ctx_size): return shared.args.ctx_size else: return shared.settings['truncation_length'] def load_grammar(name): p = shared.user_data_dir / 'grammars' / name if p.exists(): return open(p, 'r', encoding='utf-8').read() else: return '' ================================================ FILE: modules/ui_session.py ================================================ import gradio as gr from modules import shared, ui, utils from modules.utils import gradio def create_ui(): mu = shared.args.multi_user with gr.Tab("Session", elem_id="session-tab"): with gr.Row(): with gr.Column(): gr.Markdown("## Settings") shared.gradio['toggle_dark_mode'] = gr.Button('Toggle light/dark theme 💡', elem_classes='refresh-button') shared.gradio['show_two_notebook_columns'] = gr.Checkbox(label='Show two columns in the Notebook tab', value=shared.settings['show_two_notebook_columns']) shared.gradio['paste_to_attachment'] = gr.Checkbox(label='Turn long pasted text into attachments in the Chat tab', value=shared.settings['paste_to_attachment'], elem_id='paste_to_attachment') shared.gradio['include_past_attachments'] = gr.Checkbox(label='Include attachments/search results from previous messages in the chat prompt', value=shared.settings['include_past_attachments']) with gr.Column(): gr.Markdown("## Extensions & flags") shared.gradio['save_settings'] = gr.Button(f'Save extensions settings to {shared.user_data_dir}/settings.yaml', elem_classes='refresh-button', interactive=not mu) shared.gradio['reset_interface'] = gr.Button("Apply flags/extensions and restart", interactive=not mu) with gr.Row(): with gr.Column(): 
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=utils.get_available_extensions(), value=shared.args.extensions, label="Available extensions", info='Note that some of these extensions may require manually installing Python requirements through the command: pip install -r extensions/extension_name/requirements.txt', elem_classes='checkboxgroup-table') with gr.Column(): shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=get_boolean_arguments(), value=get_boolean_arguments(active=True), label="Boolean command-line flags", elem_classes='checkboxgroup-table') shared.gradio['theme_state'] = gr.Textbox(visible=False, value='dark' if shared.settings['dark_theme'] else 'light') if not mu: shared.gradio['save_settings'].click( ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then( handle_save_settings, gradio('interface_state', 'preset_menu', 'extensions_menu', 'show_controls', 'theme_state'), gradio('save_contents', 'save_filename', 'save_root', 'save_root_state', 'file_saver'), show_progress=False) shared.gradio['toggle_dark_mode'].click( lambda x: 'dark' if x == 'light' else 'light', gradio('theme_state'), gradio('theme_state')).then( None, None, None, js=f'() => {{{ui.dark_theme_js}; toggleDarkMode(); localStorage.setItem("theme", document.body.classList.contains("dark") ? "dark" : "light")}}') shared.gradio['show_two_notebook_columns'].change( handle_default_to_notebook_change, gradio('show_two_notebook_columns', 'textbox-default', 'output_textbox', 'prompt_menu-default', 'textbox-notebook', 'prompt_menu-notebook'), gradio('default-tab', 'notebook-tab', 'textbox-default', 'output_textbox', 'prompt_menu-default', 'textbox-notebook', 'prompt_menu-notebook') ) # Reset interface event if not mu: shared.gradio['reset_interface'].click( set_interface_arguments, gradio('extensions_menu', 'bool_menu'), None).then( None, None, None, js='() => {document.body.innerHTML=\'

      Reloading...

      \'; setTimeout(function(){location.reload()},2500); return []}') def handle_save_settings(state, preset, extensions, show_controls, theme): contents = ui.save_settings(state, preset, extensions, show_controls, theme, manual_save=True) root = str(shared.user_data_dir) + "/" return [ contents, "settings.yaml", root, root, gr.update(visible=True) ] def handle_default_to_notebook_change(show_two_columns, default_input, default_output, default_prompt, notebook_input, notebook_prompt): if show_two_columns: # Notebook to default return [ gr.update(visible=True), gr.update(visible=False), notebook_input, "", gr.update(value=notebook_prompt, choices=utils.get_available_prompts()), gr.update(), gr.update(), ] else: # Default to notebook return [ gr.update(visible=False), gr.update(visible=True), gr.update(), gr.update(), gr.update(), default_input, gr.update(value=default_prompt, choices=utils.get_available_prompts()) ] def set_interface_arguments(extensions, bool_active): shared.args.extensions = extensions bool_list = get_boolean_arguments() for k in bool_list: setattr(shared.args, k, False) for k in bool_active: setattr(shared.args, k, True) if k == 'api': shared.add_extension('openai', last=True) shared.need_restart = True def get_boolean_arguments(active=False): cmd_list = vars(shared.args) bool_list = sorted([k for k in cmd_list if type(cmd_list[k]) is bool and k not in ui.list_model_elements()]) bool_active = [k for k in bool_list if vars(shared.args)[k]] if active: return bool_active else: return bool_list ================================================ FILE: modules/utils.py ================================================ import os import re from datetime import datetime from pathlib import Path from modules import shared from modules.logging_colors import logger # Helper function to get multiple values from shared.gradio def gradio(*keys): if len(keys) == 1 and type(keys[0]) in [list, tuple]: keys = keys[0] return [shared.gradio[k] for k in keys] def 
sanitize_filename(name): """Strip path traversal components from a filename. Returns only the final path component with leading dots removed, preventing directory traversal via '../' or absolute paths. """ name = Path(name).name # drop all directory components name = name.lstrip('.') # remove leading dots return name def _is_path_allowed(abs_path_str): """Check if a path is under the configured user_data directory.""" abs_path = Path(abs_path_str).resolve() user_data_resolved = shared.user_data_dir.resolve() try: abs_path.relative_to(user_data_resolved) return True except ValueError: return False def save_file(fname, contents): if fname == '': logger.error('File name is empty!') return abs_path_str = os.path.abspath(fname) if not _is_path_allowed(abs_path_str): logger.error(f'Invalid file path: \"{fname}\"') return if Path(abs_path_str).suffix.lower() not in ('.yaml', '.yml', '.json', '.txt', '.gbnf'): logger.error(f'Refusing to save file with disallowed extension: \"{fname}\"') return with open(abs_path_str, 'w', encoding='utf-8') as f: f.write(contents) logger.info(f'Saved \"{abs_path_str}\".') def delete_file(fname): if fname == '': logger.error('File name is empty!') return abs_path_str = os.path.abspath(fname) if not _is_path_allowed(abs_path_str): logger.error(f'Invalid file path: \"{fname}\"') return p = Path(abs_path_str) if p.exists(): p.unlink() logger.info(f'Deleted \"{fname}\".') def current_time(): return f"{datetime.now().strftime('%Y-%m-%d_%Hh%Mm%Ss')}" def atoi(text): return int(text) if text.isdigit() else text.lower() # Replace multiple string pairs in a string def replace_all(text, dic): for i, j in dic.items(): text = text.replace(i, j) return text def natural_keys(text): return [atoi(c) for c in re.split(r'(\d+)', text)] def check_model_loaded(): if shared.model_name == 'None' or shared.model is None: if len(get_available_models()) == 0: error_msg = f"No model is loaded.\n\nTo get started:\n1) Place a GGUF file in your 
{shared.user_data_dir}/models folder\n2) Go to the Model tab and select it" logger.error(error_msg) return False, error_msg else: error_msg = "No model is loaded. Please select one in the Model tab." logger.error(error_msg) return False, error_msg return True, None def resolve_model_path(model_name_or_path, image_model=False): """ Resolves a model path, checking for a direct path before the default models directory. """ path_candidate = Path(model_name_or_path) if path_candidate.exists(): return path_candidate elif image_model: return Path(f'{shared.args.image_model_dir}/{model_name_or_path}') else: return Path(f'{shared.args.model_dir}/{model_name_or_path}') def get_available_models(): # Get all GGUF files gguf_files = get_available_ggufs() # Filter out non-first parts of multipart GGUF files filtered_gguf_files = [] for gguf_path in gguf_files: filename = os.path.basename(gguf_path) match = re.search(r'-(\d+)-of-\d+\.gguf$', filename) if match: part_number = match.group(1) # Keep only if it's part 1 if part_number.lstrip("0") == "1": filtered_gguf_files.append(gguf_path) else: # Not a multi-part file filtered_gguf_files.append(gguf_path) model_dir = Path(shared.args.model_dir) # Find top-level directories containing GGUF files dirs_with_gguf = set() for gguf_path in gguf_files: path = Path(gguf_path) if len(path.parts) > 0: dirs_with_gguf.add(path.parts[0]) # Find directories with safetensors files dirs_with_safetensors = set() for item in os.listdir(model_dir): item_path = model_dir / item if item_path.is_dir(): if any(file.lower().endswith(('.safetensors', '.pt')) for file in os.listdir(item_path) if (item_path / file).is_file()): dirs_with_safetensors.add(item) # Find valid model directories model_dirs = [] for item in os.listdir(model_dir): item_path = model_dir / item if not item_path.is_dir(): continue # Include directory if it either doesn't contain GGUF files # or contains both GGUF and safetensors files if item not in dirs_with_gguf or item in 
dirs_with_safetensors: model_dirs.append(item) model_dirs = sorted(model_dirs, key=natural_keys) return filtered_gguf_files + model_dirs def get_available_image_models(): model_dir = Path(shared.args.image_model_dir) model_dir.mkdir(parents=True, exist_ok=True) # Find valid model directories model_dirs = [] for item in os.listdir(model_dir): item_path = model_dir / item if not item_path.is_dir(): continue model_dirs.append(item) model_dirs = sorted(model_dirs, key=natural_keys) return model_dirs def get_available_ggufs(): model_list = [] model_dir = Path(shared.args.model_dir) for dirpath, _, files in os.walk(model_dir, followlinks=True): for file in files: if file.lower().endswith(".gguf"): model_path = Path(dirpath) / file rel_path = model_path.relative_to(model_dir) model_list.append(str(rel_path)) return sorted(model_list, key=natural_keys) def get_available_mmproj(): mmproj_dir = shared.user_data_dir / 'mmproj' if not mmproj_dir.exists(): return ['None'] mmproj_files = [] for item in mmproj_dir.iterdir(): if item.is_file() and item.suffix.lower() in ('.gguf', '.bin'): mmproj_files.append(item.name) return ['None'] + sorted(mmproj_files, key=natural_keys) def get_available_presets(): return sorted(set((k.stem for k in (shared.user_data_dir / 'presets').glob('*.yaml'))), key=natural_keys) def get_available_prompts(): notebook_dir = shared.user_data_dir / 'logs' / 'notebook' notebook_dir.mkdir(parents=True, exist_ok=True) prompt_files = list(notebook_dir.glob('*.txt')) if not prompt_files: new_name = current_time() new_path = notebook_dir / f"{new_name}.txt" new_path.write_text("In this story,", encoding='utf-8') prompt_files = [new_path] sorted_files = sorted(prompt_files, key=lambda x: x.stat().st_mtime, reverse=True) prompts = [file.stem for file in sorted_files] return prompts def get_available_characters(): paths = (x for x in (shared.user_data_dir / 'characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return sorted(set((k.stem for k in 
paths)), key=natural_keys) def get_available_users(): users_dir = shared.user_data_dir / 'users' users_dir.mkdir(parents=True, exist_ok=True) paths = (x for x in users_dir.iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_instruction_templates(): path = str(shared.user_data_dir / "instruction-templates") paths = [] if os.path.exists(path): paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths)), key=natural_keys) def get_available_extensions(): # User extensions (higher priority) user_extensions = [] user_ext_path = shared.user_data_dir / 'extensions' if user_ext_path.exists(): user_exts = map(lambda x: x.parent.name, user_ext_path.glob('*/script.py')) user_extensions = sorted(set(user_exts), key=natural_keys) # System extensions (excluding those overridden by user extensions) system_exts = map(lambda x: x.parent.name, Path('extensions').glob('*/script.py')) system_extensions = sorted(set(system_exts) - set(user_extensions), key=natural_keys) return user_extensions + system_extensions def get_available_loras(): return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=natural_keys) def get_datasets(path: str, ext: str): # include subdirectories for raw txt files to allow training from a subdirectory of txt files if ext == "txt": return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('*.txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys) return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys) def get_chat_datasets(path: str): """List JSON datasets that contain chat conversations (messages or ShareGPT format).""" return ['None'] + sorted(set([k.stem for k in 
Path(path).glob('*.json') if k.stem != 'put-trainer-datasets-here' and _is_chat_dataset(k)]), key=natural_keys) def get_text_datasets(path: str): """List JSON datasets that contain raw text ({"text": ...} format).""" return ['None'] + sorted(set([k.stem for k in Path(path).glob('*.json') if k.stem != 'put-trainer-datasets-here' and _is_text_dataset(k)]), key=natural_keys) def _peek_json_keys(filepath): """Read the first object in a JSON array file and return its keys.""" import json decoder = json.JSONDecoder() WS = ' \t\n\r' try: with open(filepath, 'r', encoding='utf-8') as f: buf = '' obj_start = None while len(buf) < 1 << 20: # Read up to 1MB chunk = f.read(8192) if not chunk: break buf += chunk if obj_start is None: idx = 0 while idx < len(buf) and buf[idx] in WS: idx += 1 if idx >= len(buf): continue if buf[idx] != '[': return set() idx += 1 while idx < len(buf) and buf[idx] in WS: idx += 1 if idx >= len(buf): continue obj_start = idx try: obj, _ = decoder.raw_decode(buf, obj_start) if isinstance(obj, dict): return set(obj.keys()) return set() except json.JSONDecodeError: continue except Exception: pass return set() def _is_chat_dataset(filepath): keys = _peek_json_keys(filepath) return bool(keys & {'messages', 'conversations'}) def _is_text_dataset(filepath): keys = _peek_json_keys(filepath) return 'text' in keys def get_available_chat_styles(): return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys) def get_available_grammars(): return ['None'] + sorted([item.name for item in list((shared.user_data_dir / 'grammars').glob('*.gbnf'))], key=natural_keys) ================================================ FILE: modules/web_search.py ================================================ import concurrent.futures import html import ipaddress import random import re import socket from concurrent.futures import as_completed from datetime import datetime from urllib.parse import parse_qs, quote_plus, urljoin, 
urlparse import requests from modules import shared from modules.logging_colors import logger def _validate_url(url): """Validate that a URL is safe to fetch (not targeting private/internal networks).""" parsed = urlparse(url) if parsed.scheme not in ('http', 'https'): raise ValueError(f"Unsupported URL scheme: {parsed.scheme}") hostname = parsed.hostname if not hostname: raise ValueError("No hostname in URL") # Resolve hostname and check all returned addresses try: for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None): ip = ipaddress.ip_address(sockaddr[0]) if not ip.is_global: raise ValueError(f"Access to non-public address {ip} is blocked") except socket.gaierror: raise ValueError(f"Could not resolve hostname: {hostname}") def get_current_timestamp(): """Returns the current time in 24-hour format""" return datetime.now().strftime('%b %d, %Y %H:%M') def download_web_page(url, timeout=10, include_links=False): """ Download a web page and extract its main content as Markdown text. 
""" import trafilatura try: _validate_url(url) headers = { 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' } max_redirects = 5 for _ in range(max_redirects): response = requests.get(url, headers=headers, timeout=timeout, allow_redirects=False) if response.is_redirect and 'Location' in response.headers: url = urljoin(url, response.headers['Location']) _validate_url(url) else: break response.raise_for_status() result = trafilatura.extract( response.text, include_links=include_links, output_format='markdown', url=url ) return result or "" except requests.exceptions.RequestException as e: logger.error(f"Error downloading {url}: {e}") return "" except Exception as e: logger.error(f"An unexpected error occurred: {e}") return "" def perform_web_search(query, num_pages=3, max_workers=5, timeout=10, fetch_content=True): """Perform web search and return results, optionally with page content""" try: search_url = f"https://html.duckduckgo.com/html/?q={quote_plus(query)}" agents = [ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36" ] response = requests.get(search_url, headers={'User-Agent': random.choice(agents)}, timeout=timeout) response.raise_for_status() response_text = response.text # Extract results - title and URL come from the same element result_links = re.findall(r']*class="[^"]*result__a[^"]*"[^>]*>(.*?)', response_text, re.DOTALL) result_tags = re.findall(r']*class="[^"]*result__a[^"]*"[^>]*)>', response_text, re.DOTALL) # Prepare download tasks download_tasks = [] for i, (tag_attrs, raw_title) in enumerate(zip(result_tags, result_links)): if num_pages is not None and i >= num_pages: break # Extract href and resolve the actual URL from DuckDuckGo's redirect link href_match = re.search(r'href="([^"]*)"', tag_attrs) if not href_match: continue uddg = 
parse_qs(urlparse(html.unescape(href_match.group(1))).query).get('uddg', [''])[0] if not uddg: continue title = html.unescape(re.sub(r'<[^>]+>', '', raw_title).strip()) download_tasks.append((uddg, title, len(download_tasks))) search_results = [None] * len(download_tasks) # Pre-allocate to maintain order if not fetch_content: for url, title, index in download_tasks: search_results[index] = { 'title': title, 'url': url, 'content': '' } return search_results # Download pages in parallel with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all download tasks future_to_task = { executor.submit(download_web_page, task[0]): task for task in download_tasks } # Collect results as they complete for future in as_completed(future_to_task): url, title, index = future_to_task[future] try: content = future.result() search_results[index] = { 'title': title, 'url': url, 'content': content } except Exception: search_results[index] = { 'title': title, 'url': url, 'content': '' } return search_results except Exception as e: logger.error(f"Error performing web search: {e}") return [] def truncate_content_by_tokens(content, max_tokens=8192): """Truncate content to fit within token limit using binary search""" if len(shared.tokenizer.encode(content)) <= max_tokens: return content left, right = 0, len(content) while left < right: mid = (left + right + 1) // 2 if len(shared.tokenizer.encode(content[:mid])) <= max_tokens: left = mid else: right = mid - 1 return content[:left] def add_web_search_attachments(history, row_idx, user_message, search_query, state): """Perform web search and add results as attachments""" if not search_query: logger.warning("No search query provided") return try: logger.info(f"Using search query: {search_query}") # Perform web search num_pages = int(state.get('web_search_pages', 3)) search_results = perform_web_search(search_query, num_pages) if not search_results: logger.warning("No search results found") return # Filter out 
failed downloads before adding attachments successful_results = [result for result in search_results if result['content'].strip()] if not successful_results: logger.warning("No successful downloads to add as attachments") return # Add search results as attachments key = f"user_{row_idx}" if key not in history['metadata']: history['metadata'][key] = {"timestamp": get_current_timestamp()} if "attachments" not in history['metadata'][key]: history['metadata'][key]["attachments"] = [] for result in successful_results: attachment = { "name": result['title'], "type": "text/html", "url": result['url'], "content": truncate_content_by_tokens(result['content']) } history['metadata'][key]["attachments"].append(attachment) logger.info(f"Added {len(successful_results)} successful web search results as attachments.") except Exception as e: logger.error(f"Error in web search: {e}") ================================================ FILE: one_click.py ================================================ import argparse import glob import hashlib import json import os import platform import re import signal import site import subprocess import sys # Define the required versions TORCH_VERSION = "2.9.1" PYTHON_VERSION = "3.13" LIBSTDCXX_VERSION_LINUX = "12.1.0" # Environment script_dir = os.getcwd() conda_env_path = os.path.join(script_dir, "installer_files", "env") state_file = '.installer_state.json' # Command-line flags flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update-wizard'])}" def signal_handler(sig, frame): sys.exit(0) signal.signal(signal.SIGINT, signal_handler) def is_linux(): return sys.platform.startswith("linux") def is_windows(): return sys.platform.startswith("win") def is_macos(): return sys.platform.startswith("darwin") def is_x86_64(): return platform.machine() == "x86_64" def is_installed(): site_packages_path = None for sitedir in site.getsitepackages(): if "site-packages" in sitedir and conda_env_path in sitedir: site_packages_path = sitedir break 
if site_packages_path: return os.path.isfile(os.path.join(site_packages_path, 'torch', '__init__.py')) else: return os.path.isdir(conda_env_path) def load_state(): """Load installer state from JSON file""" if os.path.exists(state_file): try: with open(state_file, 'r') as f: return json.load(f) except Exception: return {} return {} def save_state(state): """Save installer state to JSON file""" with open(state_file, 'w') as f: json.dump(state, f) def get_gpu_choice(): """Get GPU choice from state file or ask user""" state = load_state() gpu_choice = state.get('gpu_choice') if not gpu_choice: if "GPU_CHOICE" in os.environ: choice = os.environ["GPU_CHOICE"].upper() print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.") else: choice = get_user_choice( "What is your GPU?", { 'A': 'NVIDIA', 'B': 'AMD - Linux only, ROCm 7.2', 'C': 'Apple M Series', 'D': 'Intel Arc (beta)', 'N': 'CPU mode' }, ) # Convert choice to GPU name gpu_choice = {"A": "NVIDIA_CUDA128", "B": "AMD", "C": "APPLE", "D": "INTEL", "N": "NONE"}[choice] # Save choice to state state['gpu_choice'] = gpu_choice save_state(state) return gpu_choice def get_pytorch_install_command(gpu_choice): """Get PyTorch installation command based on GPU choice""" base_cmd = f"python -m pip install torch=={TORCH_VERSION} " pypi_fallback = " --extra-index-url https://pypi.org/simple/" if gpu_choice == "NVIDIA_CUDA128": return base_cmd + "--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback elif gpu_choice == "AMD": py_tag = f"cp{PYTHON_VERSION.replace('.', '')}" return f"python -m pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl" elif gpu_choice in ["APPLE", "NONE"]: return base_cmd + "--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback elif gpu_choice == "INTEL": return base_cmd + "--index-url https://download.pytorch.org/whl/xpu" else: return base_cmd def 
get_pytorch_update_command(gpu_choice): """Get PyTorch update command based on GPU choice""" base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} " pypi_fallback = " --extra-index-url https://pypi.org/simple/" if gpu_choice == "NVIDIA_CUDA128": return f"{base_cmd}--index-url https://download.pytorch.org/whl/cu128" + pypi_fallback elif gpu_choice == "AMD": py_tag = f"cp{PYTHON_VERSION.replace('.', '')}" return f"python -m pip install --upgrade https://repo.radeon.com/rocm/manylinux/rocm-rel-7.2/torch-{TORCH_VERSION}%2Brocm7.2.0.lw.git7e1940d4-{py_tag}-{py_tag}-linux_x86_64.whl" elif gpu_choice in ["APPLE", "NONE"]: return f"{base_cmd}--index-url https://download.pytorch.org/whl/cpu" + pypi_fallback elif gpu_choice == "INTEL": return f"{base_cmd}--index-url https://download.pytorch.org/whl/xpu" else: return base_cmd def get_requirements_file(gpu_choice): """Get requirements file path based on GPU choice""" requirements_base = os.path.join("requirements", "full") if gpu_choice == "NVIDIA_CUDA128": file_name = "requirements.txt" elif gpu_choice == "AMD": file_name = "requirements_amd.txt" elif gpu_choice == "APPLE": file_name = f"requirements_apple_{'intel' if is_x86_64() else 'silicon'}.txt" elif gpu_choice in ["INTEL", "NONE"]: file_name = "requirements_cpu_only.txt" else: raise ValueError(f"Unknown GPU choice: {gpu_choice}") return os.path.join(requirements_base, file_name) def get_current_commit(): result = run_cmd("git rev-parse HEAD", capture_output=True, environment=True) return result.stdout.decode('utf-8').strip() def get_extensions_names(): return [foldername for foldername in os.listdir('extensions') if os.path.isfile(os.path.join('extensions', foldername, 'requirements.txt'))] def check_env(): # If we have access to conda, we are probably in an environment conda_exist = run_cmd("conda", environment=True, capture_output=True).returncode == 0 if not conda_exist: print("Conda is not installed. 
Exiting...") sys.exit(1) # Ensure this is a new environment and not the base environment if os.environ.get("CONDA_DEFAULT_ENV", "") == "base": print("Create an environment for this project and activate it. Exiting...") sys.exit(1) def clear_cache(): run_cmd("conda clean -a -y", environment=True) run_cmd("python -m pip cache purge", environment=True) def run_cmd(cmd, assert_success=False, environment=False, capture_output=False, env=None): # Use the conda environment if environment: if is_windows(): conda_bat_path = os.path.join(script_dir, "installer_files", "conda", "condabin", "conda.bat") python_path = os.path.join(conda_env_path, "python.exe") cmd = cmd.replace("python ", f'"{python_path}" ') cmd = f'"{conda_bat_path}" activate "{conda_env_path}" >nul && {cmd}' else: conda_sh_path = os.path.join(script_dir, "installer_files", "conda", "etc", "profile.d", "conda.sh") cmd = f'. "{conda_sh_path}" && conda activate "{conda_env_path}" && {cmd}' # Set executable to None for Windows, bash for everything else executable = None if is_windows() else 'bash' # Run shell commands result = subprocess.run(cmd, shell=True, capture_output=capture_output, env=env, executable=executable) # Assert the command ran successfully if assert_success and result.returncode != 0: print(f"Command '{cmd}' failed with exit status code '{str(result.returncode)}'.\n\nExiting now.\nTry running the start/update script again.") sys.exit(1) return result def print_big_message(message): message = message.strip() lines = message.split('\n') print("\n\n*******************************************************************") for line in lines: print("*", line) print("*******************************************************************\n\n") def calculate_file_hash(file_path): p = os.path.join(script_dir, file_path) if os.path.isfile(p): with open(p, 'rb') as f: return hashlib.sha256(f.read()).hexdigest() else: return '' def generate_alphabetic_sequence(index): result = '' while index >= 0: index, remainder 
= divmod(index, 26) result = chr(ord('A') + remainder) + result index -= 1 return result def get_user_choice(question, options_dict): print() print(question) print() for key, value in options_dict.items(): print(f"{key}) {value}") print() choice = input("Input> ").upper() while choice not in options_dict.keys(): print("Invalid choice. Please try again.") choice = input("Input> ").upper() return choice def update_pytorch_and_python(): print_big_message("Checking for PyTorch updates.") gpu_choice = get_gpu_choice() install_cmd = get_pytorch_update_command(gpu_choice) run_cmd(install_cmd, assert_success=True, environment=True) def clean_outdated_pytorch_cuda_dependencies(): patterns = ["cu121", "cu122", "rocm6", "torch2.4", "torch2.6", "torch2.7", "torchvision", "torchaudio"] result = run_cmd("python -m pip list --format=freeze", capture_output=True, environment=True) matching_packages = [] for line in result.stdout.decode('utf-8').splitlines(): if "==" in line: pkg_name, version = line.split('==', 1) if any(pattern in version for pattern in patterns): matching_packages.append(pkg_name) if matching_packages: print(f"\nUninstalling: {', '.join(matching_packages)}\n") run_cmd(f"python -m pip uninstall -y {' '.join(matching_packages)}", assert_success=True, environment=True) return matching_packages def install_webui(): if os.path.isfile(state_file): os.remove(state_file) # Get GPU choice and save it to state gpu_choice = get_gpu_choice() # Write a flag to CMD_FLAGS.txt for CPU mode if gpu_choice == "NONE": cmd_flags_path = os.path.join(script_dir, "user_data", "CMD_FLAGS.txt") with open(cmd_flags_path, 'r+') as cmd_flags_file: if "--cpu" not in cmd_flags_file.read(): print_big_message("Adding the --cpu flag to user_data/CMD_FLAGS.txt.") cmd_flags_file.write("\n--cpu\n") # Handle CUDA version display elif any((is_windows(), is_linux())) and gpu_choice == "NVIDIA_CUDA128": print("CUDA: 12.8") # No PyTorch for AMD on Windows elif is_windows() and gpu_choice == "AMD": 
print("PyTorch setup on Windows is not implemented yet. Exiting...") sys.exit(1) # Install Git and then Pytorch print_big_message("Installing PyTorch.") install_pytorch = get_pytorch_install_command(gpu_choice) run_cmd(f"conda install -y ninja git && {install_pytorch}", assert_success=True, environment=True) # Install the webui requirements update_requirements(initial_installation=True, pull=False) def update_requirements(initial_installation=False, pull=True): # Create .git directory if missing if not os.path.exists(os.path.join(script_dir, ".git")): run_cmd( "git init -b main && git remote add origin https://github.com/oobabooga/text-generation-webui && " "git fetch && git symbolic-ref refs/remotes/origin/HEAD refs/remotes/origin/main && " "git reset --hard origin/main && git branch --set-upstream-to=origin/main", environment=True, assert_success=True ) # Check for outdated Python version and refuse to update if '.'.join(map(str, sys.version_info[:2])) != PYTHON_VERSION: print_big_message( "Your current installation uses Python {}.{}, which is outdated.\n" "Python {} is now required. A clean installation is needed.\n\n" "INSTRUCTIONS:\n" "1. Delete the 'installer_files' folder in your text-generation-webui directory.\n" "2. Run the start script again (e.g., start_windows.bat).\n\n" "This will create a fresh environment with the latest software.".format(*sys.version_info[:2], PYTHON_VERSION) ) sys.exit(0) # Check for outdated CUDA 12.4 installs and refuse to update state = load_state() if state.get('gpu_choice') == 'NVIDIA': print_big_message( "Your current installation uses CUDA 12.4, which has been removed.\n" "To update to the new default (CUDA 12.8), a clean installation is required.\n\n" "INSTRUCTIONS:\n" "1. Delete the 'installer_files' folder in your text-generation-webui directory.\n" "2. Run the start script again (e.g., start_windows.bat).\n\n" "This will create a fresh environment with the latest software." 
) sys.exit(0) current_commit = get_current_commit() wheels_changed = not os.path.exists(state_file) installed_wheels = set() if not wheels_changed: state = load_state() installed_wheels = set(state.get('installed_wheels', [])) if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit: wheels_changed = True gpu_choice = get_gpu_choice() requirements_file = get_requirements_file(gpu_choice) if pull: # Read .whl lines before pulling before_pull_whl_lines = [] if os.path.exists(requirements_file): with open(requirements_file, 'r') as f: before_pull_whl_lines = [line for line in f if '.whl' in line] print_big_message('Updating the local copy of the repository with "git pull"') # Hash files before pulling files_to_check = [ 'start_linux.sh', 'start_macos.sh', 'start_windows.bat', 'start_wsl.bat', 'update_wizard_linux.sh', 'update_wizard_macos.sh', 'update_wizard_windows.bat', 'update_wizard_wsl.bat', 'one_click.py' ] before_hashes = {file: calculate_file_hash(file) for file in files_to_check} # Perform the git pull run_cmd("git pull --autostash", assert_success=True, environment=True) current_commit = get_current_commit() # Check hashes after pulling after_hashes = {file: calculate_file_hash(file) for file in files_to_check} if os.path.exists(requirements_file): with open(requirements_file, 'r') as f: after_pull_whl_lines = [line for line in f if '.whl' in line] wheels_changed = wheels_changed or (before_pull_whl_lines != after_pull_whl_lines) # Check for changes to installer files for file in files_to_check: if before_hashes[file] != after_hashes[file]: print_big_message(f"File '{file}' was updated during 'git pull'. 
Please run the script again.") # Save state before exiting state = load_state() state['last_installed_commit'] = current_commit if wheels_changed: state['wheels_changed'] = True save_state(state) sys.exit(1) if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"): install_extensions_requirements() if is_linux(): run_cmd(f"conda install -y -c conda-forge 'libstdcxx-ng>={LIBSTDCXX_VERSION_LINUX}'", assert_success=True, environment=True) # Update PyTorch if not initial_installation: update_pytorch_and_python() clean_outdated_pytorch_cuda_dependencies() print_big_message(f"Installing webui requirements from file: {requirements_file}") print(f"GPU Choice: {gpu_choice}\n") # Prepare the requirements file textgen_requirements = open(requirements_file).read().splitlines() all_whl_lines = [line.strip() for line in textgen_requirements if '.whl' in line] if not initial_installation: if installed_wheels: # Per-wheel comparison: only re-download wheels that changed textgen_requirements = [ line for line in textgen_requirements if '.whl' not in line or line.strip() not in installed_wheels ] elif not wheels_changed: textgen_requirements = [line for line in textgen_requirements if '.whl' not in line] with open('temp_requirements.txt', 'w') as file: file.write('\n'.join(textgen_requirements)) # Workaround for git+ packages not updating properly. 
git_requirements = [req for req in textgen_requirements if req.startswith("git+")] for req in git_requirements: url = req.replace("git+", "") package_name = url.split("/")[-1].split("@")[0].rstrip(".git") run_cmd(f"python -m pip uninstall -y {package_name}", environment=True) print(f"Uninstalled {package_name}") # Install/update the project requirements run_cmd("python -m pip install -r temp_requirements.txt --upgrade", assert_success=True, environment=True) # Save state after successful installation state = load_state() state['last_installed_commit'] = current_commit state['installed_wheels'] = all_whl_lines state.pop('wheels_changed', None) save_state(state) # Clean up os.remove('temp_requirements.txt') clear_cache() def install_extensions_requirements(): print_big_message("Installing extensions requirements.\nSome of these may fail on Windows.\nDon\'t worry if you see error messages, as they will not affect the main program.") extensions = get_extensions_names() for i, extension in enumerate(extensions): print(f"\n\n--- [{i + 1}/{len(extensions)}]: {extension}\n\n") extension_req_path = os.path.join("extensions", extension, "requirements.txt") run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True) def launch_webui(): run_cmd(f"python server.py {flags}", environment=True) if __name__ == "__main__": # Verifies we are in a conda environment check_env() parser = argparse.ArgumentParser(add_help=False) parser.add_argument('--update-wizard', action='store_true', help='Launch a menu with update options.') args, _ = parser.parse_known_args() if args.update_wizard: while True: choice = get_user_choice( "What would you like to do?", { 'A': 'Update the web UI', 'B': 'Install/update extensions requirements', 'C': 'Revert local changes to repository files with \"git reset --hard\"', 'N': 'Nothing (exit)' }, ) if choice == 'A': update_requirements() elif choice == 'B': choices = {'A': 'All extensions'} for i, name in 
enumerate(get_extensions_names()): key = generate_alphabetic_sequence(i + 1) choices[key] = name choice = get_user_choice("What extension?", choices) if choice == 'A': install_extensions_requirements() else: extension_req_path = os.path.join("extensions", choices[choice], "requirements.txt") run_cmd(f"python -m pip install -r {extension_req_path} --upgrade", assert_success=False, environment=True) update_requirements(pull=False) elif choice == 'C': run_cmd("git reset --hard", assert_success=True, environment=True) elif choice == 'N': sys.exit() else: if not is_installed(): install_webui() os.chdir(script_dir) if os.environ.get("LAUNCH_AFTER_INSTALL", "").lower() in ("no", "n", "false", "0", "f", "off"): print_big_message("Will now exit due to LAUNCH_AFTER_INSTALL.") sys.exit() # Check if a model has been downloaded yet if '--model-dir' in flags: # Splits on ' ' or '=' while maintaining spaces within quotes flags_list = re.split(' +(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)|=', flags) model_dir = [flags_list[(flags_list.index(flag) + 1)] for flag in flags_list if flag == '--model-dir'][0].strip('"\'') else: model_dir = 'user_data/models' if len([item for item in glob.glob(f'{model_dir}/*') if not item.endswith(('.txt', '.yaml'))]) == 0: print_big_message("You haven't downloaded any model yet.\nOnce the web UI launches, head over to the \"Model\" tab and download one.") # Workaround for llama-cpp-python loading paths in CUDA env vars even if they do not exist conda_path_bin = os.path.join(conda_env_path, "bin") if not os.path.exists(conda_path_bin): os.mkdir(conda_path_bin) # Launch the webui launch_webui() ================================================ FILE: requirements/full/requirements.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" bitsandbytes==0.49.* datasets diffusers==0.37.* einops fastapi==0.112.4 flash-linear-attention==0.4.* huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas 
peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* trafilatura==2.0.0 transformers==5.3.* triton-windows==3.5.1.post24; platform_system == "Windows" tqdm wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # CUDA wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" https://github.com/turboderp-org/exllamav3/releases/download/v0.0.25/exllamav3-0.0.25+cu128.torch2.9.0-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-win_amd64.whl; platform_system == "Windows" and python_version == "3.13" https://github.com/kingbri1/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu128torch2.9.0cxx11abiFALSE-cp313-cp313-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.13" ================================================ FILE: requirements/full/requirements_amd.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets 
diffusers==0.37.* einops fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* transformers==5.3.* tqdm trafilatura==2.0.0 wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # AMD wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" ================================================ FILE: requirements/full/requirements_apple_intel.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets diffusers==0.37.* einops fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* transformers==5.3.* tqdm trafilatura==2.0.0 wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # Mac wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; 
platform_system == "Darwin" ================================================ FILE: requirements/full/requirements_apple_silicon.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets diffusers==0.37.* einops fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* transformers==5.3.* tqdm trafilatura==2.0.0 wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # Mac wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" ================================================ FILE: requirements/full/requirements_cpu_only.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets diffusers==0.37.* einops fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* transformers==5.3.* tqdm trafilatura==2.0.0 wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only) 
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" ================================================ FILE: requirements/full/requirements_nowheels.txt ================================================ accelerate==1.12.* audioop-lts<1.0; python_version >= "3.13" datasets diffusers==0.37.* einops fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pandas peft==0.18.* Pillow>=9.5.0 pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich safetensors==0.7.* scipy sentencepiece tensorboard torchao==0.15.* transformers==5.3.* tqdm trafilatura==2.0.0 wandb # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken ================================================ FILE: requirements/portable/requirements.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # CUDA wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-win_amd64.whl; 
platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu124-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" ================================================ FILE: requirements/portable/requirements_amd.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # AMD wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-win_amd64.whl; platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+rocm7.2-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" ================================================ FILE: requirements/portable/requirements_apple_intel.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # Mac wheels 
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_x86_64.whl; platform_system == "Darwin" ================================================ FILE: requirements/portable/requirements_apple_silicon.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # Mac wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0-py3-none-macosx_13_0_arm64.whl; platform_system == "Darwin" ================================================ FILE: requirements/portable/requirements_cpu_only.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # llama.cpp (CPU only) https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" 
https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cpu-py3-none-win_amd64.whl; platform_system == "Windows" ================================================ FILE: requirements/portable/requirements_cuda131.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # CUDA wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-win_amd64.whl; platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+cu131-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" ================================================ FILE: requirements/portable/requirements_nowheels.txt ================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken ================================================ FILE: requirements/portable/requirements_vulkan.txt 
================================================ audioop-lts<1.0; python_version >= "3.13" fastapi==0.112.4 huggingface-hub==1.5.* jinja2==3.1.6 markdown numpy==2.2.* pydantic==2.11.0 pymupdf==1.27.1 python-docx==1.1.2 pyyaml requests rich trafilatura==2.0.0 tqdm # Gradio https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio-4.37.2+custom.11-py3-none-any.whl https://github.com/oobabooga/gradio/releases/download/4.37.2-custom.11/gradio_client-1.0.2+custom.11-py3-none-any.whl # API flask_cloudflared==0.0.15 sse-starlette==1.6.5 tiktoken # Vulkan wheels https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-win_amd64.whl; platform_system == "Windows" https://github.com/oobabooga/llama-cpp-binaries/releases/download/v0.94.0/llama_cpp_binaries-0.94.0+vulkan-py3-none-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" ================================================ FILE: server.py ================================================ import os import signal import sys import time import warnings from functools import partial from pathlib import Path from threading import Lock, Thread import yaml from modules import shared, utils from modules.image_models import load_image_model from modules.logging_colors import logger from modules.prompts import load_prompt import modules.extensions as extensions_module from modules.LoRA import add_lora_to_model from modules.models import load_model, unload_model_if_idle from modules.models_settings import ( get_fallback_settings, get_model_metadata, update_model_parameters ) from modules.shared import do_cmd_flags_warnings os.environ['BITSANDBYTES_NOWELCOME'] = '1' warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Using the update method is deprecated') warnings.filterwarnings('ignore', category=UserWarning, message='Field 
"model_name" has conflict') warnings.filterwarnings('ignore', category=UserWarning, message='Field "model_names" has conflict') def signal_handler(sig, frame): # On second Ctrl+C, force an immediate exit signal.signal(signal.SIGINT, signal.SIG_DFL) signal.signal(signal.SIGTERM, signal.SIG_DFL) logger.info("Received Ctrl+C. Shutting down Text Generation Web UI gracefully.") # Explicitly stop LlamaServer to avoid __del__ cleanup issues during shutdown if shared.model and shared.model.__class__.__name__ == 'LlamaServer': try: shared.model.stop() except Exception: pass sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) def create_interface(): import shutil import gradio as gr from modules import ( training, ui, ui_chat, ui_default, ui_file_saving, ui_image_generation, ui_model_menu, ui_notebook, ui_parameters, ui_session, ) from modules.chat import generate_pfp_cache from modules.extensions import apply_extensions from modules.utils import gradio warnings.filterwarnings('ignore', category=UserWarning, message='The value passed into gr.Dropdown()') # Set up Gradio temp directory path gradio_temp_path = shared.user_data_dir / 'cache' / 'gradio' shutil.rmtree(gradio_temp_path, ignore_errors=True) gradio_temp_path.mkdir(parents=True, exist_ok=True) os.environ.update({ 'GRADIO_ANALYTICS_ENABLED': 'False', 'GRADIO_TEMP_DIR': str(gradio_temp_path) }) title = 'Text Generation Web UI' # Password authentication auth = [] if shared.args.gradio_auth: auth.extend(x.strip() for x in shared.args.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()) if shared.args.gradio_auth_path: with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: auth.extend(x.strip() for line in file for x in line.split(',') if x.strip()) auth = [tuple(cred.split(':')) for cred in auth] # Allowed paths allowed_paths = ["css", "js", "extensions", str(shared.user_data_dir / "cache")] if not shared.args.multi_user: 
allowed_paths.append(str(shared.user_data_dir / "image_outputs")) # Import the extensions and execute their setup() functions if shared.args.extensions is not None and len(shared.args.extensions) > 0: extensions_module.load_extensions() # Force some events to be triggered on page load shared.persistent_interface_state.update({ 'mode': shared.settings['mode'], 'loader': shared.args.loader or 'llama.cpp', 'filter_by_loader': (shared.args.loader or 'All') if not shared.args.portable else 'llama.cpp' }) if not shared.settings['prompt-notebook']: shared.settings['prompt-notebook'] = utils.get_available_prompts()[0] prompt = load_prompt(shared.settings['prompt-notebook']) shared.persistent_interface_state.update({ 'textbox-default': prompt, 'textbox-notebook': prompt }) # Clear existing cache files for cache_file in ['pfp_character.png', 'pfp_character_thumb.png']: cache_path = shared.user_data_dir / "cache" / cache_file if cache_path.exists(): cache_path.unlink() # Regenerate for default character if shared.settings['mode'] != 'instruct': generate_pfp_cache(shared.settings['character']) # css/js strings css = ui.css js = ui.js css += apply_extensions('css') js += apply_extensions('js') # Interface state elements shared.input_elements = ui.list_interface_input_elements() # Head HTML for font preloads, KaTeX, highlight.js, morphdom, and global JS head_html = '\n'.join([ '', '', '', '', '', '', '', '', '', '', f'', '', f'', ]) with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme, head=head_html, dark_theme=shared.settings['dark_theme']) as shared.gradio['interface']: # Interface state shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) # Audio notification if (shared.user_data_dir / "notification.mp3").exists(): shared.gradio['audio_notification'] = gr.Audio(interactive=False, value=str(shared.user_data_dir / "notification.mp3"), elem_id="audio_notification", visible=False) # Floating menus for saving/deleting files 
ui_file_saving.create_ui() # Temporary clipboard for saving files shared.gradio['temporary_text'] = gr.Textbox(visible=False) # Chat tab ui_chat.create_ui() # Notebook tab with gr.Tab("Notebook", elem_id='notebook-parent-tab'): ui_default.create_ui() ui_notebook.create_ui() ui_parameters.create_ui() # Parameters tab ui_chat.create_character_settings_ui() # Character tab ui_model_menu.create_ui() # Model tab if not shared.args.portable: ui_image_generation.create_ui() # Image generation tab training.create_ui() # Training tab ui_session.create_ui() # Session tab # Generation events ui_chat.create_event_handlers() ui_default.create_event_handlers() ui_notebook.create_event_handlers() if not shared.args.portable: ui_image_generation.create_event_handlers() # Other events ui_file_saving.create_event_handlers() ui_parameters.create_event_handlers() ui_model_menu.create_event_handlers() # UI persistence events ui.setup_auto_save() # Interface launch events shared.gradio['interface'].load( None, gradio('show_controls'), None, js=f"""(x) => {{ {js} {ui.show_controls_js} toggle_controls(x); }}""" ) shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False) # Sync theme_state with the actual client-side theme so that # autosave always writes the correct dark_theme value. shared.gradio['interface'].load(None, None, gradio('theme_state'), js='() => document.body.classList.contains("dark") ? 
"dark" : "light"') extensions_module.create_extensions_tabs() # Extensions tabs extensions_module.create_extensions_block() # Extensions block # Launch the interface shared.gradio['interface'].queue() shared.gradio['interface'].launch( max_threads=64, prevent_thread_lock=True, share=shared.args.share, server_name=None if not shared.args.listen else (shared.args.listen_host or '0.0.0.0'), server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth or None, ssl_verify=False if (shared.args.ssl_keyfile or shared.args.ssl_certfile) else True, ssl_keyfile=shared.args.ssl_keyfile, ssl_certfile=shared.args.ssl_certfile, root_path=shared.args.subpath, allowed_paths=allowed_paths, ) if __name__ == "__main__": logger.info("Starting Text Generation Web UI") do_cmd_flags_warnings() # Load custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): settings_file = Path(shared.args.settings) elif (shared.user_data_dir / 'settings.yaml').exists(): settings_file = shared.user_data_dir / 'settings.yaml' if settings_file is not None: logger.info(f"Loading settings from \"{settings_file}\"") with open(settings_file, 'r', encoding='utf-8') as f: new_settings = yaml.safe_load(f.read()) if new_settings: shared.settings.update(new_settings) # Apply CLI overrides for image model settings (CLI flags take precedence over saved settings) shared.apply_image_model_cli_overrides() # Fallback settings for models shared.model_config['.*'] = get_fallback_settings() shared.model_config.move_to_end('.*', last=False) # Move to the beginning # Activate the extensions listed on settings.yaml extensions_module.available_extensions = utils.get_available_extensions() for extension in shared.settings['default_extensions']: shared.args.extensions = shared.args.extensions or [] if extension not in shared.args.extensions: shared.args.extensions.append(extension) # Load image model if specified via CLI if shared.args.image_model: 
logger.info(f"Loading image model: {shared.args.image_model}") result = load_image_model( shared.args.image_model, dtype=shared.settings.get('image_dtype', 'bfloat16'), attn_backend=shared.settings.get('image_attn_backend', 'sdpa'), cpu_offload=shared.settings.get('image_cpu_offload', False), compile_model=shared.settings.get('image_compile', False), quant_method=shared.settings.get('image_quant', 'none') ) if result is not None: shared.image_model_name = shared.args.image_model else: logger.error(f"Failed to load image model: {shared.args.image_model}") available_models = utils.get_available_models() # Model defined through --model if shared.args.model is not None: shared.model_name = shared.args.model # Select the model from a command-line menu elif shared.args.model_menu: if len(available_models) == 0: logger.error('No models are available! Please download at least one.') sys.exit(0) else: print('The following models are available:\n') for i, model in enumerate(available_models): print(f'{i+1}. {model}') print(f'\nWhich one do you want to load? 
1-{len(available_models)}\n') i = int(input()) - 1 print() shared.model_name = available_models[i] # If any model has been selected, load it if shared.model_name != 'None': model_settings = get_model_metadata(shared.model_name) update_model_parameters(model_settings, initial=True) # hijack the command-line arguments # Load the model shared.model, shared.tokenizer = load_model(shared.model_name) if shared.args.lora: add_lora_to_model(shared.args.lora) shared.generation_lock = Lock() if shared.args.idle_timeout > 0: timer_thread = Thread(target=unload_model_if_idle) timer_thread.daemon = True timer_thread.start() if shared.args.nowebui: # Start the API in standalone mode shared.args.extensions = [x for x in (shared.args.extensions or []) if x != 'gallery'] if shared.args.extensions: extensions_module.load_extensions() else: # Launch the web UI create_interface() while True: time.sleep(0.5) if shared.need_restart: shared.need_restart = False time.sleep(0.5) shared.gradio['interface'].close() time.sleep(0.5) create_interface() ================================================ FILE: setup.cfg ================================================ [pycodestyle] max-line-length = 120 ignore = E402, E501, E722 ================================================ FILE: start_linux.sh ================================================ #!/usr/bin/env bash # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" # Portable install case if [ -d "portable_env" ]; then ./portable_env/bin/python3 server.py --portable --api --auto-launch "$@" exit $? fi if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. 
&& exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null OS_ARCH=$(uname -m) case "${OS_ARCH}" in x86_64*) OS_ARCH="x86_64";; arm64*) OS_ARCH="aarch64";; aarch64*) OS_ARCH="aarch64";; *) echo "Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64" && exit esac # config INSTALL_DIR="$(pwd)/installer_files" CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" MINIFORGE_DOWNLOAD_URL="https://github.com/conda-forge/miniforge/releases/download/26.1.0-0/Miniforge3-26.1.0-0-Linux-${OS_ARCH}.sh" conda_exists="F" # figure out whether git and conda needs to be installed if "$CONDA_ROOT_PREFIX/bin/conda" --version &>/dev/null; then conda_exists="T"; fi # (if necessary) install git and conda into a contained environment # download miniforge if [ "$conda_exists" == "F" ]; then echo "Downloading Miniforge from $MINIFORGE_DOWNLOAD_URL to $INSTALL_DIR/miniforge_installer.sh" mkdir -p "$INSTALL_DIR" curl -L "$MINIFORGE_DOWNLOAD_URL" > "$INSTALL_DIR/miniforge_installer.sh" chmod u+x "$INSTALL_DIR/miniforge_installer.sh" bash "$INSTALL_DIR/miniforge_installer.sh" -b -p $CONDA_ROOT_PREFIX # test the conda binary echo "Miniforge version:" "$CONDA_ROOT_PREFIX/bin/conda" --version # delete the Miniforge installer rm "$INSTALL_DIR/miniforge_installer.sh" fi # create the installer env if [ ! -e "$INSTALL_ENV_DIR" ]; then "$CONDA_ROOT_PREFIX/bin/conda" create -y -k --prefix "$INSTALL_ENV_DIR" python=3.13 fi # check if conda environment was actually created if [ ! -e "$INSTALL_ENV_DIR/bin/python" ]; then echo "Conda environment is empty." 
exit fi export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate installer env source "$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) conda activate "$INSTALL_ENV_DIR" # setup installer env python one_click.py $@ ================================================ FILE: start_macos.sh ================================================ #!/bin/bash # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" # Portable install case if [ -d "portable_env" ]; then ./portable_env/bin/python3 server.py --portable --api --auto-launch --api-port 5005 "$@" exit $? fi if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. && exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null # M Series or Intel OS_ARCH=$(uname -m) case "${OS_ARCH}" in x86_64*) OS_ARCH="x86_64";; arm64*) OS_ARCH="arm64";; *) echo "Unknown system architecture: $OS_ARCH! 
This script runs only on x86_64 or arm64" && exit esac # config INSTALL_DIR="$(pwd)/installer_files" CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" MINIFORGE_DOWNLOAD_URL="https://github.com/conda-forge/miniforge/releases/download/26.1.0-0/Miniforge3-26.1.0-0-MacOSX-${OS_ARCH}.sh" conda_exists="F" # figure out whether git and conda needs to be installed if "$CONDA_ROOT_PREFIX/bin/conda" --version &>/dev/null; then conda_exists="T"; fi # (if necessary) install git and conda into a contained environment # download miniforge if [ "$conda_exists" == "F" ]; then echo "Downloading Miniforge from $MINIFORGE_DOWNLOAD_URL to $INSTALL_DIR/miniforge_installer.sh" mkdir -p "$INSTALL_DIR" curl -L "$MINIFORGE_DOWNLOAD_URL" > "$INSTALL_DIR/miniforge_installer.sh" chmod u+x "$INSTALL_DIR/miniforge_installer.sh" bash "$INSTALL_DIR/miniforge_installer.sh" -b -p $CONDA_ROOT_PREFIX # test the conda binary echo "Miniforge version:" "$CONDA_ROOT_PREFIX/bin/conda" --version # delete the Miniforge installer rm "$INSTALL_DIR/miniforge_installer.sh" fi # create the installer env if [ ! -e "$INSTALL_ENV_DIR" ]; then "$CONDA_ROOT_PREFIX/bin/conda" create -y -k --prefix "$INSTALL_ENV_DIR" python=3.13 fi # check if conda environment was actually created if [ ! -e "$INSTALL_ENV_DIR/bin/python" ]; then echo "Conda environment is empty." 
exit fi export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate installer env source "$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) conda activate "$INSTALL_ENV_DIR" # setup installer env python one_click.py $@ ================================================ FILE: start_windows.bat ================================================ @echo off setlocal enabledelayedexpansion @rem environment isolation set PYTHONNOUSERSITE=1 set PYTHONPATH= set PYTHONHOME= set PYTHONUTF8=1 cd /D "%~dp0" @rem Portable install case if exist "portable_env" ( .\portable_env\python.exe server.py --portable --api --auto-launch %* exit /b %errorlevel% ) set PATH=%PATH%;%SystemRoot%\system32 echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniforge which can not be silently installed under a path with spaces. && goto end @rem Check for special characters in installation path set "SPCHARMESSAGE="WARNING: Special characters were detected in the installation path!" 
" This can cause the installation to fail!"" echo "%CD%"| findstr /R /C:"[!#\$%&()\*+,;<=>?@\[\]\^`{|}~]" >nul && ( call :PrintBigMessage %SPCHARMESSAGE% ) set SPCHARMESSAGE= @rem fix failed install when installing to a separate drive set TMP=%cd%\installer_files set TEMP=%cd%\installer_files @rem deactivate existing conda envs as needed to avoid conflicts (call conda deactivate && call conda deactivate && call conda deactivate) 2>nul @rem config set INSTALL_DIR=%cd%\installer_files set CONDA_ROOT_PREFIX=%cd%\installer_files\conda set INSTALL_ENV_DIR=%cd%\installer_files\env set MINIFORGE_DOWNLOAD_URL=https://github.com/conda-forge/miniforge/releases/download/26.1.0-0/Miniforge3-26.1.0-0-Windows-x86_64.exe set MINIFORGE_CHECKSUM=0ad64473c20a8649be9313f64ee898f4b23a35a7a25ea9998a751c542e5e3840 set conda_exists=F @rem figure out whether git and conda needs to be installed call "%CONDA_ROOT_PREFIX%\_conda.exe" --version >nul 2>&1 if "%ERRORLEVEL%" EQU "0" set conda_exists=T @rem (if necessary) install git and conda into a contained environment @rem download conda if "%conda_exists%" == "F" ( echo Downloading Miniforge from %MINIFORGE_DOWNLOAD_URL% to %INSTALL_DIR%\miniforge_installer.exe mkdir "%INSTALL_DIR%" call curl -Lk "%MINIFORGE_DOWNLOAD_URL%" > "%INSTALL_DIR%\miniforge_installer.exe" || ( echo. && echo Miniforge failed to download. && goto end ) @rem Try CertUtil first for /f %%a in ('CertUtil -hashfile "%INSTALL_DIR%\miniforge_installer.exe" SHA256 ^| find /i /v " " ^| find /i "%MINIFORGE_CHECKSUM%"') do ( set "output=%%a" ) @rem If CertUtil fails, try PowerShell if not defined output ( for /f %%a in ('powershell -Command "if((Get-FileHash \"%INSTALL_DIR%\miniforge_installer.exe\" -Algorithm SHA256).Hash -eq ''%MINIFORGE_CHECKSUM%''){echo true}"') do ( set "output=%%a" ) ) if not defined output ( echo The checksum verification for miniforge_installer.exe has failed. 
del "%INSTALL_DIR%\miniforge_installer.exe" goto end ) else ( echo The checksum verification for miniforge_installer.exe has passed successfully. ) echo Installing Miniforge to %CONDA_ROOT_PREFIX% start /wait "" "%INSTALL_DIR%\miniforge_installer.exe" /InstallationType=JustMe /NoShortcuts=1 /AddToPath=0 /RegisterPython=0 /NoRegistry=1 /S /D=%CONDA_ROOT_PREFIX% @rem test the conda binary echo Miniforge version: call "%CONDA_ROOT_PREFIX%\_conda.exe" --version || ( echo. && echo Miniforge not found. && goto end ) @rem delete the Miniforge installer del "%INSTALL_DIR%\miniforge_installer.exe" ) @rem create the installer env if not exist "%INSTALL_ENV_DIR%" ( echo Packages to install: %PACKAGES_TO_INSTALL% call "%CONDA_ROOT_PREFIX%\_conda.exe" create --no-shortcuts -y -k --prefix "%INSTALL_ENV_DIR%" python=3.13 || ( echo. && echo Conda environment creation failed. && goto end ) ) @rem check if conda environment was actually created if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo Conda environment is empty. && goto end ) set "CUDA_PATH=%INSTALL_ENV_DIR%" set "CUDA_HOME=%CUDA_PATH%" @rem activate installer env call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniforge hook not found. && goto end ) @rem setup installer env call "%INSTALL_ENV_DIR%\python.exe" one_click.py %* @rem below are functions for the script next line skips these during normal execution goto end :PrintBigMessage echo. && echo. echo ******************************************************************* for %%M in (%*) do echo * %%~M echo ******************************************************************* echo. && echo. exit /b :end pause ================================================ FILE: update_wizard_linux.sh ================================================ #!/usr/bin/env bash cd "$(dirname "${BASH_SOURCE[0]}")" if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. 
&& exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null # config CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate installer env source "$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) conda activate "$INSTALL_ENV_DIR" # update installer env python one_click.py --update-wizard && echo -e "\nHave a great day!" ================================================ FILE: update_wizard_macos.sh ================================================ #!/bin/bash cd "$(dirname "${BASH_SOURCE[0]}")" if [[ "$(pwd)" =~ " " ]]; then echo This script relies on Miniforge which can not be silently installed under a path with spaces. && exit; fi # deactivate existing conda envs as needed to avoid conflicts { conda deactivate && conda deactivate && conda deactivate; } 2> /dev/null # config CONDA_ROOT_PREFIX="$(pwd)/installer_files/conda" INSTALL_ENV_DIR="$(pwd)/installer_files/env" # environment isolation export PYTHONNOUSERSITE=1 unset PYTHONPATH unset PYTHONHOME export CUDA_PATH="$INSTALL_ENV_DIR" export CUDA_HOME="$CUDA_PATH" # activate installer env source "$CONDA_ROOT_PREFIX/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) conda activate "$INSTALL_ENV_DIR" # update installer env python one_click.py --update-wizard && echo -e "\nHave a great day!" 
================================================ FILE: update_wizard_windows.bat ================================================
@echo off
@rem Run the one_click.py update wizard inside the bundled installer environment.
@rem setlocal scopes every change made below (cwd, PATH, TMP/TEMP, PYTHON*, CUDA_*,
@rem conda activation) to this script, so invoking it from an existing cmd session
@rem does not leave that session modified afterwards. Delayed expansion is left
@rem untouched so paths containing "!" keep working.
setlocal

cd /D "%~dp0"

set PATH=%PATH%;%SystemRoot%\system32

@rem Miniforge cannot be silently installed under a path containing spaces
echo "%CD%"| findstr /C:" " >nul && echo This script relies on Miniforge which can not be silently installed under a path with spaces. && goto end

@rem fix failed install when installing to a separate drive
@rem (quoted form so characters such as "&" in the path are not treated as operators)
set "TMP=%cd%\installer_files"
set "TEMP=%cd%\installer_files"

@rem deactivate existing conda envs as needed to avoid conflicts
(call conda deactivate && call conda deactivate && call conda deactivate) 2>nul

@rem config
set "CONDA_ROOT_PREFIX=%cd%\installer_files\conda"
set "INSTALL_ENV_DIR=%cd%\installer_files\env"

@rem the wizard runs the env's python.exe below; if the installer environment
@rem was never created, stop here with a clear message instead of failing later
@rem at activation
if not exist "%INSTALL_ENV_DIR%\python.exe" ( echo. && echo Installer environment not found. Run start_windows.bat first. && goto end )

@rem environment isolation
set PYTHONNOUSERSITE=1
set PYTHONPATH=
set PYTHONHOME=
set "CUDA_PATH=%INSTALL_ENV_DIR%"
set "CUDA_HOME=%CUDA_PATH%"

@rem activate installer env
call "%CONDA_ROOT_PREFIX%\condabin\conda.bat" activate "%INSTALL_ENV_DIR%" || ( echo. && echo Miniforge hook not found. && goto end )

@rem update installer env
call "%INSTALL_ENV_DIR%\python.exe" one_click.py --update-wizard && (
    echo.
    echo Have a great day!
)

:end
pause