[
  {
    "path": ".eslintignore",
    "content": "extensions\nextensions-disabled\nrepositories\nvenv"
  },
  {
    "path": ".eslintrc.js",
    "content": "/* global module */\nmodule.exports = {\n    env: {\n        browser: true,\n        es2021: true,\n    },\n    extends: \"eslint:recommended\",\n    parserOptions: {\n        ecmaVersion: \"latest\",\n    },\n    rules: {\n        \"arrow-spacing\": \"error\",\n        \"block-spacing\": \"error\",\n        \"brace-style\": \"error\",\n        \"comma-dangle\": [\"error\", \"only-multiline\"],\n        \"comma-spacing\": \"error\",\n        \"comma-style\": [\"error\", \"last\"],\n        \"curly\": [\"error\", \"multi-line\", \"consistent\"],\n        \"eol-last\": \"error\",\n        \"func-call-spacing\": \"error\",\n        \"function-call-argument-newline\": [\"error\", \"consistent\"],\n        \"function-paren-newline\": [\"error\", \"consistent\"],\n        \"indent\": [\"error\", 4],\n        \"key-spacing\": \"error\",\n        \"keyword-spacing\": \"error\",\n        \"linebreak-style\": [\"error\", \"unix\"],\n        \"no-extra-semi\": \"error\",\n        \"no-mixed-spaces-and-tabs\": \"error\",\n        \"no-multi-spaces\": \"error\",\n        \"no-redeclare\": [\"error\", {builtinGlobals: false}],\n        \"no-trailing-spaces\": \"error\",\n        \"no-unused-vars\": \"off\",\n        \"no-whitespace-before-property\": \"error\",\n        \"object-curly-newline\": [\"error\", {consistent: true, multiline: true}],\n        \"object-curly-spacing\": [\"error\", \"never\"],\n        \"operator-linebreak\": [\"error\", \"after\"],\n        \"quote-props\": [\"error\", \"consistent-as-needed\"],\n        \"semi\": [\"error\", \"always\"],\n        \"semi-spacing\": \"error\",\n        \"semi-style\": [\"error\", \"last\"],\n        \"space-before-blocks\": \"error\",\n        \"space-before-function-paren\": [\"error\", \"never\"],\n        \"space-in-parens\": [\"error\", \"never\"],\n        \"space-infix-ops\": \"error\",\n        \"space-unary-ops\": \"error\",\n        \"switch-colon-spacing\": \"error\",\n        
\"template-curly-spacing\": [\"error\", \"never\"],\n        \"unicode-bom\": \"error\",\n    },\n    globals: {\n        //script.js\n        gradioApp: \"readonly\",\n        executeCallbacks: \"readonly\",\n        onAfterUiUpdate: \"readonly\",\n        onOptionsChanged: \"readonly\",\n        onUiLoaded: \"readonly\",\n        onUiUpdate: \"readonly\",\n        uiCurrentTab: \"writable\",\n        uiElementInSight: \"readonly\",\n        uiElementIsVisible: \"readonly\",\n        //ui.js\n        opts: \"writable\",\n        all_gallery_buttons: \"readonly\",\n        selected_gallery_button: \"readonly\",\n        selected_gallery_index: \"readonly\",\n        switch_to_txt2img: \"readonly\",\n        switch_to_img2img_tab: \"readonly\",\n        switch_to_img2img: \"readonly\",\n        switch_to_sketch: \"readonly\",\n        switch_to_inpaint: \"readonly\",\n        switch_to_inpaint_sketch: \"readonly\",\n        switch_to_extras: \"readonly\",\n        get_tab_index: \"readonly\",\n        create_submit_args: \"readonly\",\n        restart_reload: \"readonly\",\n        updateInput: \"readonly\",\n        onEdit: \"readonly\",\n        //extraNetworks.js\n        requestGet: \"readonly\",\n        popup: \"readonly\",\n        // profilerVisualization.js\n        createVisualizationTable: \"readonly\",\n        // from python\n        localization: \"readonly\",\n        // progressbar.js\n        randomId: \"readonly\",\n        requestProgress: \"readonly\",\n        // imageviewer.js\n        modalPrevImage: \"readonly\",\n        modalNextImage: \"readonly\",\n        // localStorage.js\n        localSet: \"readonly\",\n        localGet: \"readonly\",\n        localRemove: \"readonly\",\n        // resizeHandle.js\n        setupResizeHandle: \"writable\"\n    }\n};\n"
  },
  {
    "path": ".git-blame-ignore-revs",
    "content": "# Apply ESlint\n9c54b78d9dde5601e916f308d9a9d6953ec39430"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "content": "name: Bug Report\ndescription: You think something is broken in the UI\ntitle: \"[Bug]: \"\nlabels: [\"bug-report\"]\n\nbody:\n  - type: markdown\n    attributes:\n      value: |\n        > The title of the bug report should be short and descriptive.\n        > Use relevant keywords for searchability.\n        > Do not leave it blank, but also do not put an entire error log in it.\n  - type: checkboxes\n    attributes:\n      label: Checklist\n      description: |\n        Please perform basic debugging to see if extensions or configuration is the cause of the issue.\n        Basic debug procedure\n        　1. Disable all third-party extensions - check if extension is the cause\n        　2. Update extensions and webui - sometimes things just need to be updated\n        　3. Back up and remove your config.json and ui-config.json - check if the issue is caused by bad configuration\n        　4. Delete venv with third-party extensions disabled - sometimes extensions might cause wrong libraries to be installed\n        　5. Try a fresh installation of webui in a different directory - see if a clean installation solves the issue\n        Before making an issue report, please check that the issue hasn't been reported recently.\n      options:\n        - label: The issue exists after disabling all extensions\n        - label: The issue exists on a clean installation of webui\n        - label: The issue is caused by an extension, but I believe it is caused by a bug in the webui\n        - label: The issue exists in the current version of the webui\n        - label: The issue has not been reported before recently\n        - label: The issue has been reported before but has not been fixed yet\n  - type: markdown\n    attributes:\n      value: |\n        > Please fill this form with as much information as possible. 
Don't forget to \"Upload Sysinfo\" and \"What browsers\" and provide screenshots if possible\n  - type: textarea\n    id: what-did\n    attributes:\n      label: What happened?\n      description: Tell us what happened in a very clear and simple way\n      placeholder: |\n        txt2img is not working as intended.\n    validations:\n      required: true\n  - type: textarea\n    id: steps\n    attributes:\n      label: Steps to reproduce the problem\n      description: Please provide us with precise step by step instructions on how to reproduce the bug\n      placeholder: |\n        1. Go to ...\n        2. Press ...\n        3. ...\n    validations:\n      required: true\n  - type: textarea\n    id: what-should\n    attributes:\n      label: What should have happened?\n      description: Tell us what you think the normal behavior should be\n      placeholder: |\n        WebUI should ...\n    validations:\n      required: true\n  - type: dropdown\n    id: browsers\n    attributes:\n      label: What browsers do you use to access the UI ?\n      multiple: true\n      options:\n        - Mozilla Firefox\n        - Google Chrome\n        - Brave\n        - Apple Safari\n        - Microsoft Edge\n        - Android\n        - iOS\n        - Other\n  - type: textarea\n    id: sysinfo\n    attributes:\n      label: Sysinfo\n      description: System info file, generated by WebUI. You can generate it in settings, on the Sysinfo page. Drag the file into the field to upload it. If you submit your report without including the sysinfo file, the report will be closed. If needed, review the report to make sure it includes no personal information you don't want to share. If you can't start WebUI, you can use --dump-sysinfo commandline argument to generate the file.\n      placeholder: |\n        1. Go to WebUI Settings -> Sysinfo -> Download system info.\n            If WebUI fails to launch, use --dump-sysinfo commandline argument to generate the file\n        2. 
Upload the Sysinfo as an attached file. Do NOT paste it in as plain text.\n    validations:\n      required: true\n  - type: textarea\n    id: logs\n    attributes:\n      label: Console logs\n      description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after the bug occurred. If it's very long, provide a link to pastebin or a similar service.\n      render: Shell\n    validations:\n      required: true\n  - type: textarea\n    id: misc\n    attributes:\n      label: Additional information\n      description: | \n        Please provide us with any relevant additional info or context.\n        Examples:\n        　I have updated my GPU driver recently.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links:\n  - name: WebUI Community Support\n    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions\n    about: Please ask and answer questions here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "content": "name: Feature request\ndescription: Suggest an idea for this project\ntitle: \"[Feature Request]: \"\nlabels: [\"enhancement\"]\n\nbody:\n  - type: checkboxes\n    attributes:\n      label: Is there an existing issue for this?\n      description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit.\n      options:\n        - label: I have searched the existing issues and checked the recent builds/commits\n          required: true\n  - type: markdown\n    attributes:\n      value: |\n        *Please fill this form with as much information as possible, provide screenshots and/or illustrations of the feature if possible*\n  - type: textarea\n    id: feature\n    attributes:\n      label: What would your feature do ?\n      description: Tell us about your feature in a very clear and simple way, and what problem it would solve\n    validations:\n      required: true\n  - type: textarea\n    id: workflow\n    attributes:\n      label: Proposed workflow\n      description: Please provide us with step by step information on how you'd like the feature to be accessed and used\n      value: |\n        1. Go to .... \n        2. Press ....\n        3. ...\n    validations:\n      required: true\n  - type: textarea\n    id: misc\n    attributes:\n      label: Additional information\n      description: Add any other context or screenshots about the feature request here.\n"
  },
  {
    "path": ".github/pull_request_template.md",
    "content": "## Description\n\n* a simple description of what you're trying to accomplish\n* a summary of changes in code\n* which issues it fixes, if any\n\n## Screenshots/videos:\n\n\n## Checklist:\n\n- [ ] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)\n- [ ] I have performed a self-review of my own code\n- [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style)\n- [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests)\n"
  },
  {
    "path": ".github/workflows/on_pull_request.yaml",
    "content": "name: Linter\n\non:\n  - push\n  - pull_request\n\njobs:\n  lint-python:\n    name: ruff\n    runs-on: ubuntu-latest\n    if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name\n    steps:\n      - name: Checkout Code\n        uses: actions/checkout@v4\n      - uses: actions/setup-python@v5\n        with:\n          python-version: 3.11\n          # NB: there's no cache: pip here since we're not installing anything\n          #     from the requirements.txt file(s) in the repository; it's faster\n          #     not to have GHA download an (at the time of writing) 4 GB cache\n          #     of PyTorch and other dependencies.\n      - name: Install Ruff\n        run: pip install ruff==0.3.3\n      - name: Run Ruff\n        run: ruff .\n  lint-js:\n    name: eslint\n    runs-on: ubuntu-latest\n    if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name\n    steps:\n      - name: Checkout Code\n        uses: actions/checkout@v4\n      - name: Install Node.js\n        uses: actions/setup-node@v4\n        with:\n          node-version: 18\n      - run: npm i --ci\n      - run: npm run lint\n"
  },
  {
    "path": ".github/workflows/run_tests.yaml",
    "content": "name: Tests\n\non:\n  - push\n  - pull_request\n\njobs:\n  test:\n    name: tests on CPU with empty model\n    runs-on: ubuntu-latest\n    if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name\n    steps:\n      - name: Checkout Code\n        uses: actions/checkout@v4\n      - name: Set up Python 3.10\n        uses: actions/setup-python@v5\n        with:\n          python-version: 3.10.6\n          cache: pip\n          cache-dependency-path: |\n            **/requirements*txt\n            launch.py\n      - name: Cache models\n        id: cache-models\n        uses: actions/cache@v4\n        with:\n          path: models\n          key: \"2023-12-30\"\n      - name: Install test dependencies\n        run: pip install wait-for-it -r requirements-test.txt\n        env:\n          PIP_DISABLE_PIP_VERSION_CHECK: \"1\"\n          PIP_PROGRESS_BAR: \"off\"\n      - name: Setup environment\n        run: python launch.py --skip-torch-cuda-test --exit\n        env:\n          PIP_DISABLE_PIP_VERSION_CHECK: \"1\"\n          PIP_PROGRESS_BAR: \"off\"\n          TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu\n          WEBUI_LAUNCH_LIVE_OUTPUT: \"1\"\n          PYTHONUNBUFFERED: \"1\"\n      - name: Print installed packages\n        run: pip freeze\n      - name: Start test server\n        run: >\n          python -m coverage run\n          --data-file=.coverage.server\n          launch.py\n          --skip-prepare-environment\n          --skip-torch-cuda-test\n          --test-server\n          --do-not-download-clip\n          --no-half\n          --disable-opt-split-attention\n          --use-cpu all\n          --api-server-stop\n          2>&1 | tee output.txt &\n      - name: Run tests\n        run: |\n          wait-for-it --service 127.0.0.1:7860 -t 20\n          python -m pytest -vv --junitxml=test/results.xml --cov . 
--cov-report=xml --verify-base-url test\n      - name: Kill test server\n        if: always()\n        run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10\n      - name: Show coverage\n        run: |\n          python -m coverage combine .coverage*\n          python -m coverage report -i\n          python -m coverage html -i\n      - name: Upload main app output\n        uses: actions/upload-artifact@v4\n        if: always()\n        with:\n          name: output\n          path: output.txt\n      - name: Upload coverage HTML\n        uses: actions/upload-artifact@v4\n        if: always()\n        with:\n          name: htmlcov\n          path: htmlcov\n"
  },
  {
    "path": ".github/workflows/warns_merge_master.yml",
    "content": "name: Pull requests can't target master branch\n\n\"on\":\n  pull_request:\n    types:\n      - opened\n      - synchronize\n      - reopened\n    branches:\n      - master\n\njobs:\n  check:\n    runs-on: ubuntu-latest\n    steps:\n      - name: Warn about merge into master\n        run: |\n          echo -e \"::warning::This pull request merges directly into the \\\"master\\\" branch; normally development happens on the \\\"dev\\\" branch.\"\n          exit 1\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__\n*.ckpt\n*.safetensors\n*.pth\n.DS_Store\n/ESRGAN/*\n/SwinIR/*\n/repositories\n/venv\n/tmp\n/model.ckpt\n/models/**/*\n/GFPGANv1.3.pth\n/gfpgan/weights/*.pth\n/ui-config.json\n/outputs\n/config.json\n/log\n/webui.settings.bat\n/embeddings\n/styles.csv\n/params.txt\n/styles.csv.bak\n/webui-user.bat\n/webui-user.sh\n/interrogate\n/user.css\n/.idea\nnotification.mp3\n/SwinIR\n/textual_inversion\n.vscode\n/extensions\n/test/stdout.txt\n/test/stderr.txt\n/cache.json*\n/config_states/\n/node_modules\n/package-lock.json\n/.coverage*\n/test/test_outputs\n/cache\ntrace.json\n/sysinfo-????-??-??-??-??.json\n"
  },
  {
    "path": ".pylintrc",
    "content": "# See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html\n[MESSAGES CONTROL]\ndisable=C,R,W,E,I\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "## 1.10.1\r\n\r\n### Bug Fixes:\r\n* fix image upscale on cpu ([#16275](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16275))\r\n\r\n\r\n## 1.10.0\r\n\r\n### Features:\r\n* A lot of performance improvements (see below in Performance section)\r\n* Stable Diffusion 3 support ([#16030](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16030), [#16164](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16164), [#16212](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16212))\r\n  * Recommended Euler sampler; DDIM and other timestamp samplers currently not supported\r\n  * T5 text model is disabled by default, enable it in settings\r\n* New schedulers:\r\n  * Align Your Steps ([#15751](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15751))\r\n  * KL Optimal ([#15608](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608))\r\n  * Normal ([#16149](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16149))\r\n  * DDIM ([#16149](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16149))\r\n  * Simple ([#16142](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16142))\r\n  * Beta ([#16235](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16235))\r\n* New sampler: DDIM CFG++ ([#16035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16035))\r\n\r\n### Minor:\r\n* Option to skip CFG on early steps ([#15607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15607))\r\n* Add --models-dir option ([#15742](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15742))\r\n* Allow mobile users to open context menu by using two fingers press ([#15682](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15682))\r\n* Infotext: add Lora name as TI hashes for bundled Textual Inversion ([#15679](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15679))\r\n* Check model's hash after downloading it to prevent 
corrupted downloads ([#15602](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15602))\r\n* More extension tag filtering options ([#15627](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15627))\r\n* When saving AVIF, use JPEG's quality setting ([#15610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15610))\r\n* Add filename pattern: `[basename]` ([#15978](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15978))\r\n* Add option to enable clip skip for clip L on SDXL ([#15992](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15992))\r\n* Option to prevent screen sleep during generation ([#16001](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16001))\r\n* ToggleLivePriview button in image viewer ([#16065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16065))\r\n* Remove ui flashing on reloading and fast scrolling ([#16153](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16153))\r\n* option to disable save button log.csv ([#16242](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16242))\r\n\r\n### Extensions and API:\r\n* Add process_before_every_sampling hook ([#15984](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15984))\r\n* Return HTTP 400 instead of 404 on invalid sampler error ([#16140](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16140))\r\n\r\n### Performance:\r\n* [Performance 1/6] use_checkpoint = False ([#15803](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15803))\r\n* [Performance 2/6] Replace einops.rearrange with torch native ops ([#15804](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15804))\r\n* [Performance 4/6] Precompute is_sdxl_inpaint flag ([#15806](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15806))\r\n* [Performance 5/6] Prevent unnecessary extra networks bias backup ([#15816](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15816))\r\n* 
[Performance 6/6] Add --precision half option to avoid casting during inference ([#15820](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15820))\r\n* [Performance] LDM optimization patches ([#15824](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15824))\r\n* [Performance] Keep sigmas on CPU ([#15823](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15823))\r\n* Check for nans in unet only once, after all steps have been completed\r\n* Added option to run torch profiler for image generation\r\n\r\n### Bug Fixes:\r\n* Fix for grids without comprehensive infotexts ([#15958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15958))\r\n* feat: lora partial update precede full update ([#15943](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15943))\r\n* Fix bug where file extension had an extra '.' under some circumstances ([#15893](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15893))\r\n* Fix corrupt model initial load loop ([#15600](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15600))\r\n* Allow old sampler names in API ([#15656](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15656))\r\n* more old sampler scheduler compatibility ([#15681](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15681))\r\n* Fix Hypertile xyz ([#15831](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15831))\r\n* XYZ CSV skipinitialspace ([#15832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15832))\r\n* fix soft inpainting on mps and xpu, torch_utils.float64 ([#15815](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15815))\r\n* fix extension update when not on main branch ([#15797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15797))\r\n* update pickle safe filenames\r\n* use relative path for webui-assets css ([#15757](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15757))\r\n* When creating a virtual 
environment, upgrade pip in webui.bat/webui.sh ([#15750](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15750))\r\n* Fix AttributeError ([#15738](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15738))\r\n* use script_path for webui root in launch_utils ([#15705](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15705))\r\n* fix extra batch mode P Transparency ([#15664](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15664))\r\n* use gradio theme colors in css ([#15680](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15680))\r\n* Fix dragging text within prompt input ([#15657](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15657))\r\n* Add correct mimetype for .mjs files ([#15654](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15654))\r\n* QOL Items - handle metadata issues more cleanly for SD models, Loras and embeddings ([#15632](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15632))\r\n* replace wsl-open with wslpath and explorer.exe ([#15968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15968))\r\n* Fix SDXL Inpaint ([#15976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15976))\r\n* multi size grid ([#15988](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15988))\r\n* fix Replace preview ([#16118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16118))\r\n* Possible fix of wrong scale in weight decomposition ([#16151](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16151))\r\n* Ensure use of python from venv on Mac and Linux ([#16116](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16116))\r\n* Prioritize python3.10 over python3 if both are available on Linux and Mac (with fallback) ([#16092](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16092))\r\n* stopping generation extras ([#16085](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16085))\r\n* 
Fix SD2 loading ([#16078](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16078), [#16079](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16079))\r\n* fix infotext Lora hashes for hires fix different lora ([#16062](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16062))\r\n* Fix sampler scheduler autocorrection warning ([#16054](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16054))\r\n* fix ui flashing on reloading and fast scrolling ([#16153](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16153))\r\n* fix upscale logic ([#16239](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16239))\r\n* [bug] do not break progressbar on non-job actions (add wrap_gradio_call_no_job) ([#16202](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16202))\r\n* fix OSError: cannot write mode P as JPEG ([#16194](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16194))\r\n\r\n### Other:\r\n* fix changelog #15883 -> #15882 ([#15907](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15907))\r\n* ReloadUI backgroundColor --background-fill-primary ([#15864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15864))\r\n* Use different torch versions for Intel and ARM Macs ([#15851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15851))\r\n* XYZ override rework ([#15836](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15836))\r\n* scroll extensions table on overflow ([#15830](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15830))\r\n* img2img batch upload method ([#15817](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15817))\r\n* chore: sync v1.8.0 packages according to changelog ([#15783](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15783))\r\n* Add AVIF MIME type support to mimetype definitions ([#15739](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15739))\r\n* Update imageviewer.js 
([#15730](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15730))\r\n* no-referrer ([#15641](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15641))\r\n* .gitignore trace.json ([#15980](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15980))\r\n* Bump spandrel to 0.3.4 ([#16144](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16144))\r\n* Defunct --max-batch-count ([#16119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16119))\r\n* docs: update bug_report.yml ([#16102](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16102))\r\n* Maintaining Project Compatibility for Python 3.9 Users Without Upgrade Requirements. ([#16088](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16088), [#16169](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16169), [#16192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16192))\r\n* Update torch for ARM Macs to 2.3.1 ([#16059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16059))\r\n* remove deprecated setting dont_fix_second_order_samplers_schedule ([#16061](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16061))\r\n* chore: fix typos ([#16060](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16060))\r\n* shlex.join launch args in console log ([#16170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16170))\r\n* activate venv .bat ([#16231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16231))\r\n* add ids to the resize tabs in img2img ([#16218](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16218))\r\n* update installation guide linux ([#16178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16178))\r\n* Robust sysinfo ([#16173](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16173))\r\n* do not send image size on paste inpaint ([#16180](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16180))\r\n* Fix 
noisy DS_Store files for MacOS ([#16166](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/16166))\r\n\r\n\r\n## 1.9.4\r\n\r\n### Bug Fixes:\r\n*  pin setuptools version to fix the startup error ([#15882](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15882)) \r\n\r\n## 1.9.3\r\n\r\n### Bug Fixes:\r\n*  fix get_crop_region_v2 ([#15594](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15594)) \r\n\r\n## 1.9.2\r\n\r\n### Extensions and API:\r\n* restore 1.8.0-style naming of scripts\r\n\r\n## 1.9.1\r\n\r\n### Minor:\r\n* Add avif support ([#15582](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15582))\r\n* Add filename patterns: `[sampler_scheduler]` and `[scheduler]` ([#15581](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15581))\r\n\r\n### Extensions and API:\r\n* undo adding scripts to sys.modules\r\n* Add schedulers API endpoint ([#15577](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15577))\r\n* Remove API upscaling factor limits ([#15560](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15560))\r\n\r\n### Bug Fixes:\r\n* Fix images do not match / Coordinate 'right' is less than 'left' ([#15534](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15534))\r\n* fix: remove_callbacks_for_function should also remove from the ordered map ([#15533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15533))\r\n* fix x1 upscalers ([#15555](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15555))\r\n* Fix cls.__module__ value in extension script ([#15532](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15532))\r\n* fix typo in function call (eror -> error) ([#15531](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15531))\r\n\r\n### Other:\r\n* Hide 'No Image data blocks found.' 
message ([#15567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15567))\r\n* Allow webui.sh to be runnable from arbitrary directories containing a .git file ([#15561](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15561))\r\n* Compatibility with Debian 11, Fedora 34+ and openSUSE 15.4+ ([#15544](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15544))\r\n* numpy DeprecationWarning product -> prod ([#15547](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15547))\r\n* get_crop_region_v2 ([#15583](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15583), [#15587](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15587))\r\n\r\n\r\n## 1.9.0\r\n\r\n### Features:\r\n* Make refiner switchover based on model timesteps instead of sampling steps ([#14978](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14978))\r\n* add an option to have old-style directory view instead of tree view; stylistic changes for extra network sorting/search controls\r\n* add UI for reordering callbacks, support for specifying callback order in extension metadata ([#15205](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15205))\r\n* Sgm uniform scheduler for SDXL-Lightning models ([#15325](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15325))\r\n* Scheduler selection in main UI ([#15333](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15333), [#15361](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15361), [#15394](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15394))\r\n\r\n### Minor:\r\n* \"open images directory\" button now opens the actual dir ([#14947](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14947))\r\n* Support inference with LyCORIS BOFT networks ([#14871](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14871), [#14973](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14973))\r\n* make extra 
network card description plaintext by default, with an option to re-enable HTML as it was\r\n* resize handle for extra networks ([#15041](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15041))\r\n* cmd args: `--unix-filenames-sanitization` and `--filenames-max-length` ([#15031](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15031))\r\n* show extra networks parameters in HTML table rather than raw JSON ([#15131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15131))\r\n* Add DoRA (weight-decompose) support for LoRA/LoHa/LoKr ([#15160](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15160), [#15283](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15283))\r\n* Add '--no-prompt-history' cmd arg to disable last generation prompt history ([#15189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15189))\r\n* update preview on Replace Preview ([#15201](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15201))\r\n* only fetch updates for extensions' active git branches ([#15233](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15233))\r\n* put upscale postprocessing UI into an accordion ([#15223](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15223))\r\n* Support dragdrop for URLs to read infotext ([#15262](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15262))\r\n* use diskcache library for caching ([#15287](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15287), [#15299](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15299))\r\n* Allow PNG-RGBA for Extras Tab ([#15334](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15334))\r\n* Support cover images embedded in safetensors metadata ([#15319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15319))\r\n* faster interrupt when using NN upscale ([#15380](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15380))\r\n* Extras upscaler: 
an input field to limit maximum side length for the output image ([#15293](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15293), [#15415](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15415), [#15417](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15417), [#15425](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15425))\r\n* add an option to hide postprocessing options in Extras tab\r\n\r\n### Extensions and API:\r\n* ResizeHandleRow - allow overridden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004))\r\n* call script_callbacks.ui_settings_callback earlier; fix extra-options-section built-in extension killing the ui if using a setting that doesn't exist\r\n* make it possible to use zoom.js outside webui context ([#15286](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15286), [#15288](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15288))\r\n* allow variants for extension name in metadata.ini ([#15290](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15290))\r\n* make reloading UI scripts optional when doing Reload UI, and off by default\r\n* put request: gr.Request at start of img2img function similar to txt2img\r\n* open_folder as util ([#15442](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15442))\r\n* make it possible to import extensions' script files as `import scripts.<filename>` ([#15423](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15423))\r\n\r\n### Performance:\r\n* performance optimization for extra networks HTML pages\r\n* optimization for extra networks filtering\r\n* optimization for extra networks sorting\r\n\r\n### Bug Fixes:\r\n* prevent escape button causing an interrupt when no generation has been made yet\r\n* [bug] avoid double upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966))\r\n* possible fix for reload button not 
appearing in some cases for extra networks.\r\n* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006))\r\n* Fix resize-handle visibility for vertical layout (mobile) ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010))\r\n* register_tmp_file also for mtime ([#15012](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15012))\r\n* Protect alphas_cumprod during refiner switchover ([#14979](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14979))\r\n* Fix EXIF orientation in API image loading ([#15062](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15062))\r\n* Only override emphasis if actually used in prompt ([#15141](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15141))\r\n* Fix emphasis infotext missing from `params.txt` ([#15142](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15142))\r\n* fix extract_style_text_from_prompt #15132 ([#15135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15135))\r\n* Fix Soft Inpaint for AnimateDiff ([#15148](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15148))\r\n* edit-attention: deselect surrounding whitespace ([#15178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15178))\r\n* chore: fix font not loaded ([#15183](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15183))\r\n* use natural sort in extra networks when ordering by path\r\n* Fix built-in lora system bugs caused by torch.nn.MultiheadAttention ([#15190](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15190))\r\n* Avoid error from None in get_learned_conditioning ([#15191](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15191))\r\n* Add entry to MassFileLister after writing metadata ([#15199](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15199))\r\n* fix issue with Styles when 
Hires prompt is used ([#15269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15269), [#15276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15276))\r\n* Strip comments from hires fix prompt ([#15263](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15263))\r\n* Make imageviewer event listeners browser consistent ([#15261](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15261))\r\n* Fix AttributeError in OFT when trying to get MultiheadAttention weight ([#15260](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15260))\r\n* Add missing .mean() back ([#15239](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15239))\r\n* fix \"Restore progress\" button ([#15221](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15221))\r\n* fix ui-config for InputAccordion [custom_script_source] ([#15231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15231))\r\n* handle 0 wheel deltaY ([#15268](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15268))\r\n* prevent alt menu for firefox ([#15267](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15267))\r\n* fix: fix syntax errors ([#15179](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15179))\r\n* restore outputs path ([#15307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15307))\r\n* Escape btn_copy_path filename ([#15316](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15316))\r\n* Fix extra networks buttons when filename contains an apostrophe ([#15331](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15331))\r\n* escape brackets in lora random prompt generator ([#15343](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15343))\r\n* fix: Python version check for PyTorch installation compatibility ([#15390](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15390))\r\n* fix typo in call_queue.py 
([#15386](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15386))\r\n* fix: when find already_loaded model, remove loaded by array index ([#15382](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15382))\r\n* minor bug fix of sd model memory management ([#15350](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15350))\r\n* Fix CodeFormer weight ([#15414](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15414))\r\n* Fix: Remove script callbacks in ordered_callbacks_map ([#15428](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15428))\r\n* fix limited file write (thanks, Sylwia)\r\n* Fix extra-single-image API not doing upscale failed ([#15465](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15465))\r\n* error handling paste_field callables ([#15470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15470))\r\n\r\n### Hardware:\r\n* Add training support and change lspci for Ascend NPU ([#14981](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14981))\r\n* Update to ROCm5.7 and PyTorch ([#14820](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14820))\r\n* Better workaround for Navi1, removing --pre for Navi3 ([#15224](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15224))\r\n* Ascend NPU wiki page ([#15228](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15228))\r\n\r\n### Other:\r\n* Update comment for Pad prompt/negative prompt v0 to add a warning about truncation, make it override the v1 implementation\r\n* support resizable columns for touch (tablets) ([#15002](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15002))\r\n* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995))\r\n* Use `absolute` path for normalized filepath ([#15035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15035))\r\n* resizeHandle handle double 
tap ([#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065))\r\n* --dat-models-path cmd flag ([#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039))\r\n* Add a direct link to the binary release ([#15059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15059))\r\n* upscaler_utils: Reduce logging ([#15084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15084))\r\n* Fix various typos with crate-ci/typos ([#15116](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15116))\r\n* fix_jpeg_live_preview ([#15102](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15102))\r\n* [alternative fix] can't load webui if selected wrong extra option in ui ([#15121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15121))\r\n* Error handling for unsupported transparency ([#14958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14958))\r\n* Add model description to searched terms ([#15198](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15198))\r\n* bump action version ([#15272](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15272))\r\n* PEP 604 annotations ([#15259](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15259))\r\n* Automatically Set the Scale by value when user selects an Upscale Model ([#15244](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15244))\r\n* move postprocessing-for-training into builtin extensions ([#15222](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15222))\r\n* type hinting in shared.py ([#15211](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15211))\r\n* update ruff to 0.3.3\r\n* Update pytorch lightning utilities ([#15310](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15310))\r\n* Add Size as an XYZ Grid option ([#15354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15354))\r\n* Use HF_ENDPOINT variable for HuggingFace domain 
with default ([#15443](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15443))\r\n* re-add update_file_entry ([#15446](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15446))\r\n* create_infotext allow index and callable, re-work Hires prompt infotext ([#15460](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15460))\r\n* update restricted_opts to include more options for --hide-ui-dir-config ([#15492](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15492))\r\n\r\n\r\n## 1.8.0\r\n\r\n### Features:\r\n* Update torch to version 2.1.2\r\n* Soft Inpainting ([#14208](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14208))\r\n* FP8 support ([#14031](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14031), [#14327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14327))\r\n* Support for SDXL-Inpaint Model ([#14390](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14390))\r\n* Use Spandrel for upscaling and face restoration architectures ([#14425](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14425), [#14467](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14467), [#14473](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14473), [#14474](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14474), [#14477](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14477), [#14476](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14476), [#14484](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14484), [#14500](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14500), [#14501](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14501), [#14504](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14504), [#14524](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14524), [#14809](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14809))\r\n* Automatic 
backwards version compatibility (when loading infotexts from old images with program version specified, will add compatibility settings)\r\n* Implement zero terminal SNR noise schedule option (**[SEED BREAKING CHANGE](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Seed-breaking-changes#180-dev-170-225-2024-01-01---zero-terminal-snr-noise-schedule-option)**, [#14145](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14145), [#14979](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14979))\r\n* Add a [✨] button to run hires fix on selected image in the gallery (with help from [#14598](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14598), [#14626](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14626), [#14728](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14728))\r\n* [Separate assets repository](https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets); serve fonts locally rather than from google's servers\r\n* Official LCM Sampler Support ([#14583](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14583))\r\n* Add support for DAT upscaler models ([#14690](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14690), [#15039](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15039))\r\n* Extra Networks Tree View ([#14588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14588), [#14900](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14900))\r\n* NPU Support ([#14801](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14801))\r\n* Prompt comments support\r\n\r\n### Minor:\r\n* Allow pasting in WIDTHxHEIGHT strings into the width/height fields ([#14296](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14296))\r\n* add option: Live preview in full page image viewer ([#14230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14230), 
[#14307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14307))\r\n* Add keyboard shortcuts for generate/skip/interrupt ([#14269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14269))\r\n* Better TCMALLOC support on different platforms ([#14227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14227), [#14883](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14883), [#14910](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14910))\r\n* Lora not found warning ([#14464](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14464))\r\n* Adding negative prompts to Loras in extra networks ([#14475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14475))\r\n* xyz_grid: allow varying the seed along an axis separate from axis options ([#12180](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12180))\r\n* option to convert VAE to bfloat16 (implementation of [#9295](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9295))\r\n* Better IPEX support ([#14229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14229), [#14353](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14353), [#14559](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14559), [#14562](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14562), [#14597](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14597))\r\n* Option to interrupt after current generation rather than immediately ([#13653](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13653), [#14659](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14659))\r\n* Fullscreen Preview control fading/disable ([#14291](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14291))\r\n* Finer settings freezing control ([#13789](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13789))\r\n* Increase Upscaler Limits 
([#14589](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14589))\r\n* Adjust brush size with hotkeys ([#14638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14638))\r\n* Add checkpoint info to csv log file when saving images ([#14663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14663))\r\n* Make more columns resizable ([#14740](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14740), [#14884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14884))\r\n* Add an option to not overlay original image for inpainting for #14727\r\n* Add Pad conds v0 option to support same generation with DDIM as before 1.6.0\r\n* Add \"Interrupting...\" placeholder.\r\n* Button to refresh extensions list ([#14857](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14857))\r\n* Add an option to disable normalization after calculating emphasis. ([#14874](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14874))\r\n* When counting tokens, also include enabled styles (can be disabled in settings to revert to previous behavior)\r\n* Configuration for the [📂] button for image gallery ([#14947](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14947))\r\n* Support inference with LyCORIS BOFT networks ([#14871](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14871), [#14973](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14973))\r\n* support resizable columns for touch (tablets) ([#15002](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15002))\r\n\r\n### Extensions and API:\r\n* Removed packages from requirements: basicsr, gfpgan, realesrgan; as well as their dependencies: absl-py, addict, beautifulsoup4, future, gdown, grpcio, importlib-metadata, lmdb, lpips, Markdown, platformdirs, PySocks, soupsieve, tb-nightly, tensorboard-data-server, tomli, Werkzeug, yapf, zipp\r\n* Enable task ids for API 
([#14314](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14314))\r\n* add override_settings support for infotext API\r\n* rename generation_parameters_copypaste module to infotext_utils\r\n* prevent crash due to Script __init__ exception ([#14407](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14407))\r\n* Bump numpy to 1.26.2 ([#14471](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14471))\r\n* Add utility to inspect a model's dtype/device ([#14478](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14478))\r\n* Implement general forward method for all method in built-in lora ext ([#14547](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14547))\r\n* Execute model_loaded_callback after moving to target device ([#14563](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14563))\r\n* Add self to CFGDenoiserParams ([#14573](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14573))\r\n* Allow TLS with API only mode (--nowebui) ([#14593](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14593))\r\n* New callback: postprocess_image_after_composite ([#14657](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14657))\r\n* modules/api/api.py: add api endpoint to refresh embeddings list ([#14715](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14715))\r\n* set_named_arg ([#14773](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14773))\r\n* add before_token_counter callback and use it for prompt comments\r\n* ResizeHandleRow - allow overridden column scale parameter ([#15004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15004))\r\n\r\n### Performance:\r\n* Massive performance improvement for extra networks directories with a huge number of files in them in an attempt to tackle #14507 ([#14528](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14528))\r\n* Reduce unnecessary re-indexing extra networks directory 
([#14512](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14512))\r\n* Avoid unnecessary `isfile`/`exists` calls ([#14527](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14527))\r\n\r\n### Bug Fixes:\r\n* fix multiple bugs related to styles multi-file support ([#14203](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14203), [#14276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14276), [#14707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14707))\r\n* Lora fixes ([#14300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14300), [#14237](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14237), [#14546](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14546), [#14726](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14726))\r\n* Re-add setting lost as part of e294e46 ([#14266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14266))\r\n* fix extras caption BLIP ([#14330](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14330))\r\n* include infotext into saved init image for img2img ([#14452](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14452))\r\n* xyz grid handle axis_type is None ([#14394](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14394))\r\n* Update Added (Fixed) IPV6 Functionality When there is No Webui Argument Passed webui.py ([#14354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14354))\r\n* fix API thread safe issues of txt2img and img2img ([#14421](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14421))\r\n* handle selectable script_index is None ([#14487](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14487))\r\n* handle config.json failed to load ([#14525](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14525), [#14767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14767))\r\n* paste infotext cast int as float 
([#14523](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14523))\r\n* Ensure GRADIO_ANALYTICS_ENABLED is set early enough ([#14537](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14537))\r\n* Fix logging configuration again ([#14538](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14538))\r\n* Handle CondFunc exception when resolving attributes ([#14560](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14560))\r\n* Fix extras big batch crashes ([#14699](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14699))\r\n* Fix using wrong model caused by alias ([#14655](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14655))\r\n* Add # to the invalid_filename_chars list ([#14640](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14640))\r\n* Fix extension check for requirements ([#14639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14639))\r\n* Fix tab indexes are reset after restart UI ([#14637](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14637))\r\n* Fix nested manual cast ([#14689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14689))\r\n* Keep postprocessing upscale selected tab after restart ([#14702](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14702))\r\n* XYZ grid: filter out blank vals when axis is int or float type (like int axis seed) ([#14754](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14754))\r\n* fix CLIP Interrogator topN regex ([#14775](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14775))\r\n* Fix dtype error in MHA layer/change dtype checking mechanism for manual cast ([#14791](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14791))\r\n* catch load style.csv error ([#14814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14814))\r\n* fix error when editing extra networks card\r\n* fix extra networks metadata failing to work properly when you create 
the .json file with metadata for the first time.\r\n* util.walk_files extensions case insensitive ([#14879](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14879))\r\n* if extensions page not loaded, prevent apply ([#14873](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14873))\r\n* call the right function for token counter in img2img\r\n* Fix the bugs that search/reload will disappear when using other ExtraNetworks extensions ([#14939](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14939))\r\n* Gracefully handle mtime read exception from cache ([#14933](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14933))\r\n* Only trigger interrupt on `Esc` when interrupt button visible ([#14932](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14932))\r\n* Disable prompt token counters option actually disables token counting rather than just hiding results.\r\n* avoid double upscaling in inpaint ([#14966](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14966))\r\n* Fix #14591 using translated content to do categories mapping ([#14995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14995))\r\n* fix: the `split_threshold` parameter does not work when running Split oversized images ([#15006](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15006))\r\n* Fix resize-handle for mobile ([#15010](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15010), [#15065](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15065))\r\n\r\n### Other:\r\n* Assign id for \"extra_options\". Replace numeric field with slider. 
([#14270](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14270))\r\n* change state dict comparison to ref compare ([#14216](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14216))\r\n* Bump torch-rocm to 5.6/5.7 ([#14293](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14293))\r\n* Base output path off data path ([#14446](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14446))\r\n* reorder training preprocessing modules in extras tab ([#14367](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14367))\r\n* Remove `cleanup_models` code ([#14472](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14472))\r\n* only rewrite ui-config when there is change ([#14352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14352))\r\n* Fix lint issue from 501993eb ([#14495](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14495))\r\n* Update README.md ([#14548](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14548))\r\n* hires button, fix seeds ()\r\n* Logging: set formatter correctly for fallback logger too ([#14618](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14618))\r\n* Read generation info from infotexts rather than json for internal needs (save, extract seed from generated pic) ([#14645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14645))\r\n* improve get_crop_region ([#14709](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14709))\r\n* Bump safetensors' version to 0.4.2 ([#14782](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14782))\r\n* add tooltip create_submit_box ([#14803](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14803))\r\n* extensions tab table row hover highlight ([#14885](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14885))\r\n* Always add timestamp to displayed image ([#14890](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14890))\r\n* Added 
core.filemode=false so doesn't track changes in file permission… ([#14930](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14930))\r\n* Normalize command-line argument paths ([#14934](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14934), [#15035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15035))\r\n* Use original App Title in progress bar ([#14916](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14916))\r\n* register_tmp_file also for mtime ([#15012](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15012))\r\n\r\n## 1.7.0\r\n\r\n### Features:\r\n* settings tab rework: add search field, add categories, split UI settings page into many\r\n* add altdiffusion-m18 support ([#13364](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13364))\r\n* support inference with LyCORIS GLora networks ([#13610](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13610))\r\n* add lora-embedding bundle system ([#13568](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13568))\r\n* option to move prompt from top row into generation parameters\r\n* add support for SSD-1B ([#13865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13865))\r\n* support inference with OFT networks ([#13692](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13692))\r\n* script metadata and DAG sorting mechanism ([#13944](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13944))\r\n* support HyperTile optimization ([#13948](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13948))\r\n* add support for SD 2.1 Turbo ([#14170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14170))\r\n* remove Train->Preprocessing tab and put all its functionality into Extras tab\r\n* initial IPEX support for Intel Arc GPU ([#14171](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14171))\r\n\r\n### Minor:\r\n* allow reading model hash from images in img2img 
batch mode ([#12767](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12767))\r\n* add option to align with sgm repo's sampling implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))\r\n* extra field for lora metadata viewer: `ss_output_name` ([#12838](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12838))\r\n* add action in settings page to calculate all SD checkpoint hashes ([#12909](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12909))\r\n* add button to copy prompt to style editor ([#12975](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12975))\r\n* add --skip-load-model-at-start option ([#13253](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13253))\r\n* write infotext to gif images\r\n* read infotext from gif images ([#13068](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13068))\r\n* allow configuring the initial state of InputAccordion in ui-config.json ([#13189](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13189))\r\n* allow editing whitespace delimiters for ctrl+up/ctrl+down prompt editing ([#13444](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13444))\r\n* prevent accidentally closing popup dialogs ([#13480](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13480))\r\n* added option to play notification sound or not ([#13631](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13631))\r\n* show the preview image in the full screen image viewer if available ([#13459](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13459))\r\n* support for webui.settings.bat ([#13638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13638))\r\n* add an option to not print stack traces on ctrl+c\r\n* start/restart generation by Ctrl (Alt) + Enter ([#13644](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13644))\r\n* update prompts_from_file script to allow 
concatenating entries with the general prompt ([#13733](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13733))\r\n* added a visible checkbox to input accordion\r\n* added an option to hide all txt2img/img2img parameters in an accordion ([#13826](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13826))\r\n* added 'Path' sorting option for Extra network cards ([#13968](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13968))\r\n* enable prompt hotkeys in style editor ([#13931](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13931))\r\n* option to show batch img2img results in UI ([#14009](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14009))\r\n* infotext updates: add option to disregard certain infotext fields, add option to not include VAE in infotext, add explanation to infotext settings page, move some options to infotext settings page\r\n* add FP32 fallback support on sd_vae_approx ([#14046](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046))\r\n* support XYZ scripts / split hires path from unet ([#14126](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14126))\r\n* allow use of multiple styles csv files ([#14125](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14125))\r\n* make extra network card description plaintext by default, with an option (Treat card description as HTML) to re-enable HTML as it was (originally by [#13241](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13241))\r\n\r\n### Extensions and API:\r\n* update gradio to 3.41.2\r\n* support installed extensions list api ([#12774](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12774))\r\n* update pnginfo API to return dict with parsed values\r\n* add noisy latent to `ExtraNoiseParams` for callback ([#12856](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12856))\r\n* show extension datetime in UTC 
([#12864](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12864), [#12865](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12865), [#13281](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13281))\r\n* add an option to choose how to combine hires fix and refiner\r\n* include program version in info response. ([#13135](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13135))\r\n* sd_unet support for SDXL\r\n* patch DDPM.register_betas so that users can put given_betas in model yaml ([#13276](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13276))\r\n* xyz_grid: add prepare ([#13266](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13266))\r\n* allow multiple localization files with same language in extensions ([#13077](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13077))\r\n* add onEdit function for js and rework token-counter.js to use it\r\n* fix the key error exception when processing override_settings keys ([#13567](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13567))\r\n* ability for extensions to return custom data via api in response.images ([#13463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13463))\r\n* call state.jobnext() before postproces*() ([#13762](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13762))\r\n* add option to set notification sound volume ([#13884](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13884))\r\n* update Ruff to 0.1.6 ([#14059](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14059))\r\n* add Block component creation callback ([#14119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14119))\r\n* catch uncaught exception with ui creation scripts ([#14120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14120))\r\n* use extension name for determining an extension is installed in the index 
([#14063](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14063))\r\n* update is_installed() from launch_utils.py to fix reinstalling already installed packages ([#14192](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14192))\r\n\r\n### Bug Fixes:\r\n* fix pix2pix producing bad results\r\n* fix defaults settings page breaking when any of main UI tabs are hidden\r\n* fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt\r\n* fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working\r\n* prevent duplicate resize handler ([#12795](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12795))\r\n* small typo: vae resolve bug ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))\r\n* hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))\r\n* don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))\r\n* fix style editing dialog breaking if it's opened in both img2img and txt2img tabs\r\n* hide --gradio-auth and --api-auth values from /internal/sysinfo report\r\n* add missing infotext for RNG in options ([#12819](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12819))\r\n* fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))\r\n* honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))\r\n* don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), [#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))\r\n* get progressbar to display correctly in extensions tab\r\n* keep order in list of checkpoints when loading model that 
doesn't have a checksum\r\n* fix inpainting models in txt2img creating black pictures\r\n* fix generation params regex ([#12876](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12876))\r\n* fix batch img2img output dir with script ([#12926](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12926))\r\n* fix #13080 - Hypernetwork/TI preview generation ([#13084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13084))\r\n* fix bug with sigma min/max overrides. ([#12995](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12995))\r\n* more accurate check for enabling cuDNN benchmark on 16XX cards ([#12924](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12924))\r\n* don't use multicond parser for negative prompt counter ([#13118](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13118))\r\n* fix data-sort-name containing spaces ([#13412](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13412))\r\n* update card on correct tab when editing metadata ([#13411](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13411))\r\n* fix viewing/editing metadata when filename contains an apostrophe ([#13395](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13395))\r\n* fix: --sd_model in \"Prompts from file or textbox\" script is not working ([#13302](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13302))\r\n* better Support for Portable Git ([#13231](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13231))\r\n* fix issues when webui_dir is not work_dir ([#13210](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13210))\r\n* fix: lora-bias-backup don't reset cache ([#13178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13178))\r\n* account for customizable extra network separators when removing extra network text from the prompt ([#12877](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12877))\r\n* re fix batch 
img2img output dir with script ([#13170](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13170))\r\n* fix `--ckpt-dir` path separator and option use `short name` for checkpoint dropdown ([#13139](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13139))\r\n* consolidated allowed preview formats, Fix extra network `.gif` not working as preview ([#13121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13121))\r\n* fix venv_dir=- environment variable not working as expected on linux ([#13469](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13469))\r\n* repair unload sd checkpoint button\r\n* edit-attention fixes ([#13533](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13533))\r\n* fix bug when using --gfpgan-models-path ([#13718](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13718))\r\n* properly apply sort order for extra network cards when selected from dropdown\r\n* fixes generation restart not working for some users when 'Ctrl+Enter' is pressed ([#13962](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13962))\r\n* thread safe extra network list_items ([#13014](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13014))\r\n* fix not able to exit metadata popup when pop up is too big ([#14156](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14156))\r\n* fix auto focal point crop for opencv >= 4.8 ([#14121](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14121))\r\n* make 'use-cpu all' actually apply to 'all' ([#14131](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14131))\r\n* extras tab batch: actually use original filename\r\n* make webui not crash when running with --disable-all-extensions option\r\n\r\n### Other:\r\n* non-local condition ([#12814](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12814))\r\n* fix minor typos ([#12827](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12827))\r\n* remove 
xformers Python version check ([#12842](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12842))\r\n* style: file-metadata word-break ([#12837](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12837))\r\n* revert SGM noise multiplier change for img2img because it breaks hires fix\r\n* do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))\r\n* [RC 1.6.0 - zoom is partly hidden] Update style.css ([#12839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12839))\r\n* chore: change extension time format ([#12851](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12851))\r\n* WEBUI.SH - Use torch 2.1.0 release candidate for Navi 3 ([#12929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12929))\r\n* add Fallback at images.read_info_from_image if exif data was invalid ([#13028](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13028))\r\n* update cmd arg description ([#12986](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12986))\r\n* fix: update shared.opts.data when add_option ([#12957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12957), [#13213](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13213))\r\n* restore missing tooltips ([#12976](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12976))\r\n* use default dropdown padding on mobile ([#12880](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12880))\r\n* put enable console prompts option into settings from commandline args ([#13119](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13119))\r\n* fix some deprecated types ([#12846](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12846))\r\n* bump to torchsde==0.2.6 ([#13418](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13418))\r\n* update dragdrop.js 
([#13372](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13372))\r\n* use orderdict as lru cache:opt/bug ([#13313](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13313))\r\n* XYZ if not include sub grids do not save sub grid ([#13282](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13282))\r\n* initialize state.time_start before state.job_count ([#13229](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13229))\r\n* fix fieldname regex ([#13458](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13458))\r\n* change denoising_strength default to None. ([#13466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13466))\r\n* fix regression ([#13475](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13475))\r\n* fix IndexError ([#13630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13630))\r\n* fix: checkpoints_loaded:{checkpoint:state_dict}, model.load_state_dict issue in dict value empty ([#13535](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13535))\r\n* update bug_report.yml ([#12991](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12991))\r\n* requirements_versions httpx==0.24.1 ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))\r\n* fix parenthesis auto selection ([#13829](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13829))\r\n* fix #13796 ([#13797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13797))\r\n* corrected a typo in `modules/cmd_args.py` ([#13855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13855))\r\n* feat: fix randn found element of type float at pos 2 ([#14004](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14004))\r\n* adds tqdm handler to logging_config.py for progress bar integration ([#13996](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13996))\r\n* hotfix: call shared.state.end() after postprocessing done 
([#13977](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13977))\r\n* fix dependency address patch 1 ([#13929](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13929))\r\n* save sysinfo as .json ([#14035](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14035))\r\n* move exception_records related methods to errors.py ([#14084](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14084))\r\n* compatibility ([#13936](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13936))\r\n* json.dump(ensure_ascii=False) ([#14108](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14108))\r\n* dir buttons start with / so only the correct dir will be shown and no… ([#13957](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13957))\r\n* alternate implementation for unet forward replacement that does not depend on hijack being applied\r\n* re-add `keyedit_delimiters_whitespace` setting lost as part of commit e294e46 ([#14178](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14178))\r\n* fix `save_samples` being checked early when saving masked composite ([#14177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14177))\r\n* slight optimization for mask and mask_composite ([#14181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14181))\r\n* add import_hook hack to work around basicsr/torchvision incompatibility ([#14186](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14186))\r\n\r\n## 1.6.1\r\n\r\n### Bug Fixes:\r\n * fix an error causing the webui to fail to start ([#13839](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/13839))\r\n\r\n## 1.6.0\r\n\r\n### Features:\r\n * refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371)\r\n * add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards\r\n * add style editor dialog\r\n * hires 
fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181))\r\n * option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227))\r\n * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542))\r\n * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers:\r\n   * makes all of them work with img2img\r\n   * makes prompt composition possible (AND)\r\n   * makes them available for SDXL\r\n * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808))\r\n * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599))\r\n * textual inversion inference support for SDXL\r\n * extra networks UI: show metadata for SD checkpoints\r\n * checkpoint merger: add metadata support \r\n * prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177))\r\n * VAE: allow selecting own VAE for each checkpoint (in user metadata editor)\r\n * VAE: add selected VAE to infotext\r\n * options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551))\r\n * add resize handle to txt2img and img2img 
tabs, allowing you to change the amount of horizontal space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723))\r\n * change default behavior for batching cond/uncond -- now it's on by default, and is disabled by a UI setting (Optimizations -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it\r\n * show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707))\r\n * add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models\r\n * prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457))\r\n\r\n### Minor:\r\n * img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515))\r\n * postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479))\r\n * XYZ: in the axis labels, remove pathnames from model filenames\r\n * XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298))\r\n * XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491))\r\n * add gradio version warning\r\n * sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297))\r\n * use transparent white for mask in inpainting, along with an option to select the color 
([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326))\r\n * move some settings to their own section: img2img, VAE\r\n * add checkbox to show/hide dirs for extra networks\r\n * Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311))\r\n * gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355))\r\n * sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521))\r\n * update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352))\r\n * option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338))\r\n * enable cond cache by default\r\n * git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230))\r\n * allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379))\r\n * automatically open webui in browser when running \"locally\" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254))\r\n * put commonly used samplers on top, make DPM++ 2M Karras the default choice\r\n * zoom and pan: option to 
auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727))\r\n * option to cache Lora networks in memory\r\n * rework hires fix UI to use accordion\r\n * face restoration and tiling moved to settings - use \"Options in main UI\" setting if you want them back\r\n * change quicksettings items to have variable width\r\n * Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503))\r\n * Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console\r\n * support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510))\r\n * add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564))\r\n * support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584))\r\n * make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634))\r\n * configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648))\r\n * make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645))\r\n * more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639))\r\n * make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635))\r\n * make progress bar work independently from live preview display which results in it being updated a lot more often\r\n * forbid Full live preview method for medvram and add a setting to undo the forbidding\r\n * make it possible to localize tooltips and placeholders\r\n * add option to align with sgm repo's sampling 
implementation ([#12818](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818))\r\n * Restore faces and Tiling generation parameters have been moved to settings out of main UI\r\n   * if you want to put them back into main UI, use `Options in main UI` setting on the UI page.\r\n\r\n### Extensions and API:\r\n * gradio 3.41.2\r\n * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd\r\n * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world')\r\n * properly clear the total console progressbar when using txt2img and img2img from API\r\n * add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294))\r\n * shared.py and webui.py split into many files\r\n * add --loglevel commandline argument for logging\r\n * add a custom UI element that combines accordion and checkbox\r\n * avoid importing gradio in tests because it spams warnings\r\n * put infotext label for setting into OptionInfo definition rather than in a separate list\r\n * make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470))\r\n * option to make scripts UI without gr.Group\r\n * add a way for scripts to register a callback for before/after just a single component's creation\r\n * use dataclass for StableDiffusionProcessing\r\n * store patches for Lora in a specialized module instead of inside torch\r\n * support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698))\r\n * add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616))\r\n * dump current stack traces when exiting with SIGINT\r\n * add type annotations for extra fields of 
shared.sd_model\r\n\r\n### Bug Fixes:\r\n * Don't crash if out of local storage quota for javascript localStorage\r\n * XYZ plot do not fail if an exception occurs\r\n * fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269))\r\n * localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307))\r\n * fix sdxl model invalid configuration after the hijack\r\n * correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304))\r\n * open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318))\r\n * prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319))\r\n * add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327))\r\n * fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387))\r\n * fix options in main UI misbehaving when there's just one element\r\n * make it possible to use a sampler from infotext even if it's hidden in the dropdown\r\n * fix styles missing from the prompt in infotext when making a grid of batch of multiple images\r\n * prevent bogus progress output in console when calculating hires fix dimensions\r\n * fix --use-textbox-seed\r\n * fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466))\r\n * properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463))\r\n * MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526))\r\n * add second_order to samplers that mistakenly didn't have 
it\r\n * when refreshing cards in extra networks UI, do not discard user's custom resolution\r\n * fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509))\r\n * fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588))\r\n * fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586))\r\n * fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596))\r\n * auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603))\r\n * fix exception if extensions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607))\r\n * fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638))\r\n * fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633))\r\n * attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630))\r\n * implement missing undo hijack for SDXL\r\n * fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684))\r\n * fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689))\r\n * fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685))\r\n * fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667))\r\n * create Gradio temp directory if necessary 
([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717))\r\n * prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version\r\n * set devices.dtype_unet correctly\r\n * run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737))\r\n * prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745))\r\n * fix error that causes some extra networks to be disabled if both <lora:> and <lyco:> are present in the prompt\r\n * fix defaults settings page breaking when any of main UI tabs are hidden\r\n * fix incorrect save/display of new values in Defaults page in settings\r\n * fix for Reload UI function: if you reload UI on one tab, other opened tabs will no longer stop working\r\n * fix an error that prevents VAE being reloaded after an option change if a VAE near the checkpoint exists ([#12797](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12797))\r\n * hide broken image crop tool ([#12792](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12792))\r\n * don't show hidden samplers in dropdown for XYZ script ([#12780](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12780))\r\n * fix style editing dialog breaking if it's opened in both img2img and txt2img tabs\r\n * fix a bug allowing users to bypass gradio and API authentication (reported by vysecurity) \r\n * fix notification not playing when built-in webui tab is inactive ([#12834](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12834))\r\n * honor `--skip-install` for extension installers ([#12832](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12832))\r\n * don't print blank stdout in extension installers ([#12833](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12833), 
[#12855](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12855))\r\n * do not change quicksettings dropdown option when value returned is `None` ([#12854](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12854))\r\n * get progressbar to display correctly in extensions tab\r\n\r\n\r\n## 1.5.2\r\n\r\n### Bug Fixes:\r\n * fix memory leak when generation fails\r\n * update doggettx cross attention optimization to not use an unreasonable amount of memory in some edge cases -- suggestion by MorkTheOrk\r\n\r\n\r\n## 1.5.1\r\n\r\n### Minor:\r\n * support parsing text encoder blocks in some new LoRAs\r\n * delete scale checker script due to user demand\r\n\r\n### Extensions and API:\r\n * add postprocess_batch_list script callback\r\n\r\n### Bug Fixes:\r\n * fix TI training for SD1\r\n * fix reload altclip model error\r\n * prepend the pythonpath instead of overriding it\r\n * fix typo in SD_WEBUI_RESTARTING\r\n * if txt2img/img2img raises an exception, finally call state.end()\r\n * fix composable diffusion weight parsing\r\n * restyle Startup profile for black users\r\n * fix webui not launching with --nowebui\r\n * catch exception for non git extensions\r\n * fix some options missing from /sdapi/v1/options\r\n * fix for extension update status always saying \"unknown\"\r\n * fix display of extra network cards that have `<>` in the name\r\n * update lora extension to work with python 3.8\r\n\r\n\r\n## 1.5.0\r\n\r\n### Features:\r\n * SD XL support\r\n * user metadata system for custom networks\r\n * extended Lora metadata editor: set activation text, default weight, view tags, training info\r\n * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)\r\n * show github stars for extensions\r\n * img2img batch mode can read extra stuff from png info\r\n * img2img batch works with subdirectories\r\n * hotkeys to move prompt elements: alt+left/right\r\n * restyle time taken/VRAM display\r\n * 
add textual inversion hashes to infotext\r\n * optimization: cache git extension repo information\r\n * move generate button next to the generated picture for mobile clients\r\n * hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface\r\n * skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds\r\n\r\n### Minor:\r\n * checkbox to check/uncheck all extensions in the Installed tab\r\n * add gradio user to infotext and to filename patterns\r\n * allow gif for extra network previews\r\n * add options to change colors in grid\r\n * use natural sort for items in extra networks\r\n * Mac: use empty_cache() from torch 2 to clear VRAM\r\n * added automatic support for installing the right libraries for Navi3 (AMD)\r\n * add option SWIN_torch_compile to accelerate SwinIR upscale\r\n * suppress printing TI embedding info at start to console by default\r\n * speedup extra networks listing\r\n * added `[none]` filename token.\r\n * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)\r\n * add always_discard_next_to_last_sigma option to XYZ plot\r\n * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.\r\n \r\n### Extensions and API:\r\n * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop\r\n * allow Script to have custom metaclass\r\n * add model exists status check /sdapi/v1/options\r\n * rename --add-stop-route to --api-server-stop\r\n * add `before_hr` script callback\r\n * add callback `after_extra_networks_activate`\r\n * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable\r\n * return http 404 when thumb file not found\r\n * allow replacing extensions index with environment variable\r\n \r\n### Bug Fixes:\r\n * fix for catch errors when retrieving extension index 
#11290\r\n * fix very slow loading speed of .safetensors files when reading from network drives\r\n * API cache cleanup\r\n * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode\r\n * fix warning of 'has_mps' deprecated from PyTorch\r\n * fix problem with extra network saving images as previews losing generation info\r\n * fix throwing exception when trying to resize image with I;16 mode\r\n * fix for #11534: canvas zoom and pan extension hijacking shortcut keys\r\n * fixed launch script to be runnable from any directory\r\n * don't add \"Seed Resize: -1x-1\" to API image metadata\r\n * correctly remove end parenthesis with ctrl+up/down\r\n * fixing --subpath on newer gradio version\r\n * fix: check that fill size is non-zero when resizing (fixes #11425)\r\n * use submit and blur for quick settings textbox\r\n * save img2img batch with images.save_image()\r\n * prevent running preload.py for disabled extensions\r\n * fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included\r\n\r\n\r\n## 1.4.1\r\n\r\n### Bug Fixes:\r\n * add queue lock for refresh-checkpoints\r\n\r\n## 1.4.0\r\n\r\n### Features:\r\n * zoom controls for inpainting\r\n * run basic torch calculation at startup in parallel to reduce the performance impact of first generation\r\n * option to pad prompt/neg prompt to be same length\r\n * remove taming_transformers dependency\r\n * custom k-diffusion scheduler settings\r\n * add an option to show selected settings in main txt2img/img2img UI\r\n * sysinfo tab in settings\r\n * infer styles from prompts when pasting params into the UI\r\n * an option to control the behavior of the above\r\n\r\n### Minor:\r\n * bump Gradio to 3.32.0\r\n * bump xformers to 0.0.20\r\n * Add option to disable token counters\r\n * tooltip fixes & optimizations\r\n * make it possible to configure filename for the zip download\r\n * `[vae_filename]` pattern for filenames\r\n 
* Revert discarding penultimate sigma for DPM-Solver++(2M) SDE\r\n * change UI reorder setting to multiselect\r\n * read version info from CHANGELOG.md if git version info is not available\r\n * link footer API to Wiki when API is not active\r\n * persistent conds cache (opt-in optimization)\r\n \r\n### Extensions:\r\n * After installing extensions, webui properly restarts the process rather than reloading the UI\r\n * Added VAE listing to web API. Via: /sdapi/v1/sd-vae\r\n * custom unet support\r\n * Add onAfterUiUpdate callback\r\n * refactor EmbeddingDatabase.register_embedding() to allow unregistering\r\n * add before_process callback for scripts\r\n * add ability for alwayson scripts to specify section and let user reorder those sections\r\n \r\n### Bug Fixes:\r\n * Fix dragging text to prompt\r\n * fix incorrect quoting for infotext values with colon in them\r\n * fix \"hires. fix\" prompt sharing same labels with txt2img_prompt\r\n * Fix s_min_uncond default type int\r\n * Fix for #10643 (Inpainting mask sometimes not working)\r\n * fix bad styling for thumbs view in extra networks #10639\r\n * fix for empty list of optimizations #10605\r\n * small fixes to prepare_tcmalloc for Debian/Ubuntu compatibility\r\n * fix --ui-debug-mode exit\r\n * patch GitPython to not use leaky persistent processes\r\n * fix duplicate Cross attention optimization after UI reload\r\n * torch.cuda.is_available() check for SdOptimizationXformers\r\n * fix hires fix using wrong conds in second pass if using Loras.\r\n * handle exception when parsing generation parameters from png info\r\n * fix upcast attention dtype error\r\n * forcing Torch Version to 1.13.1 for RX 5000 series GPUs\r\n * split mask blur into X and Y components, patch Outpainting MK2 accordingly\r\n * don't die when a LoRA is a broken symlink\r\n * allow activation of Generate Forever during generation\r\n\r\n\r\n## 1.3.2\r\n\r\n### Bug Fixes:\r\n * fix files served out of tmp directory even if they are saved to 
disk\r\n * fix postprocessing overwriting parameters\r\n\r\n## 1.3.1\r\n\r\n### Features:\r\n * revert default cross attention optimization to Doggettx\r\n\r\n### Bug Fixes:\r\n * fix bug: LoRA don't apply on dropdown list sd_lora\r\n * fix png info always added even if setting is not enabled\r\n * fix some fields not applying in xyz plot\r\n * fix \"hires. fix\" prompt sharing same labels with txt2img_prompt\r\n * fix lora hashes not being added properly to infotext if there is only one lora\r\n * fix --use-cpu failing to work properly at startup\r\n * make --disable-opt-split-attention command line option work again\r\n\r\n## 1.3.0\r\n\r\n### Features:\r\n * add UI to edit defaults\r\n * token merging (via dbolya/tomesd)\r\n * settings tab rework: add a lot of additional explanations and links\r\n * load extensions' Git metadata in parallel to loading the main program to save a ton of time during startup\r\n * update extensions table: show branch, show date in separate column, and show version from tags if available\r\n * TAESD - another option for cheap live previews\r\n * allow choosing sampler and prompts for second pass of hires fix - hidden by default, enabled in settings\r\n * calculate hashes for Lora\r\n * add lora hashes to infotext\r\n * when pasting infotext, use infotext's lora hashes to find local loras for `<lora:xxx:1>` entries whose hashes match loras the user has\r\n * select cross attention optimization from UI\r\n\r\n### Minor:\r\n * bump Gradio to 3.31.0\r\n * bump PyTorch to 2.0.1 for macOS and Linux AMD\r\n * allow setting defaults for elements in extensions' tabs\r\n * allow selecting file type for live previews\r\n * show \"Loading...\" for extra networks when displaying for the first time\r\n * suppress ENSD infotext for samplers that don't use it\r\n * clientside optimizations\r\n * add options to show/hide hidden files and dirs in extra networks, and to not list models/files in hidden directories\r\n * allow whitespace in styles.csv\r\n 
* add option to reorder tabs\r\n * move some functionality (swap resolution and set seed to -1) to client\r\n * option to specify editor height for img2img\r\n * button to copy image resolution into img2img width/height sliders\r\n * switch from pyngrok to ngrok-py\r\n * lazy-load images in extra networks UI\r\n * set \"Navigate image viewer with gamepad\" option to false by default, by request\r\n * change upscalers to download models into user-specified directory (from commandline args) rather than the default models/<...>\r\n * allow hiding buttons in ui-config.json\r\n\r\n### Extensions:\r\n * add /sdapi/v1/script-info api\r\n * use Ruff to lint Python code\r\n * use ESLint to lint Javascript code\r\n * add/modify CFG callbacks for Self-Attention Guidance extension\r\n * add command and endpoint for graceful server stopping\r\n * add some locals (prompts/seeds/etc) from processing function into the Processing class as fields\r\n * rework quoting for infotext items that have commas in them to use JSON (should be backwards compatible except for cases where it didn't work previously)\r\n * add /sdapi/v1/refresh-loras api checkpoint post request\r\n * tests overhaul\r\n\r\n### Bug Fixes:\r\n * fix an issue preventing the program from starting if the user specifies a bad Gradio theme\r\n * fix broken prompts from file script\r\n * fix symlink scanning for extra networks\r\n * fix --data-dir ignored when launching via webui-user.bat COMMANDLINE_ARGS\r\n * allow web UI to be run fully offline\r\n * fix inability to run with --freeze-settings\r\n * fix inability to merge checkpoint without adding metadata\r\n * fix extra networks' save preview image not adding infotext for jpeg/webp\r\n * remove blinking effect from text in hires fix and scale resolution preview\r\n * make links to `http://<...>.git` extensions work in the extension tab\r\n * fix bug with webui hanging at startup due to hanging git process\r\n\r\n\r\n## 1.2.1\r\n\r\n### Features:\r\n * add an option to 
always refer to LoRA by filenames\r\n\r\n### Bug Fixes:\r\n * never refer to LoRA by an alias if multiple LoRAs have same alias or the alias is called none\r\n * fix upscalers disappearing after the user reloads UI\r\n * allow bf16 in safe unpickler (resolves problems with loading some LoRAs)\r\n * allow web UI to be run fully offline\r\n * fix localizations not working\r\n * fix error for LoRAs: `'LatentDiffusion' object has no attribute 'lora_layer_mapping'`\r\n\r\n## 1.2.0\r\n\r\n### Features:\r\n * do not wait for Stable Diffusion model to load at startup\r\n * add filename patterns: `[denoising]`\r\n * directory hiding for extra networks: dirs starting with `.` will hide their cards on extra network tabs unless specifically searched for\r\n * LoRA: for the `<...>` text in prompt, use name of LoRA that is in the metadata of the file, if present, instead of filename (both can be used to activate LoRA)\r\n * LoRA: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active\r\n * LoRA: fix some LoRAs not working (ones that have 3x3 convolution layer)\r\n * LoRA: add an option to use old method of applying LoRAs (producing same results as with kohya-ss)\r\n * add version to infotext, footer and console output when starting\r\n * add links to wiki for filename pattern settings\r\n * add extended info for quicksettings setting and use multiselect input instead of a text field\r\n\r\n### Minor:\r\n * bump Gradio to 3.29.0\r\n * bump PyTorch to 2.0.1\r\n * `--subpath` option for gradio for use with reverse proxy\r\n * Linux/macOS: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)\r\n * do not apply localizations if there are none (possible frontend optimization)\r\n * add extra `None` option for VAE in XYZ plot\r\n * print error to console when batch processing in img2img fails\r\n * create HTML for extra network pages only on demand\r\n * allow directories starting with `.` to still 
list their models for LoRA, checkpoints, etc\r\n * put infotext options into their own category in settings tab\r\n * do not show licenses page when user selects Show all pages in settings\r\n\r\n### Extensions:\r\n * tooltip localization support\r\n * add API method to get LoRA models with prompt\r\n\r\n### Bug Fixes:\r\n * re-add `/docs` endpoint\r\n * fix gamepad navigation\r\n * make the lightbox fullscreen image function properly\r\n * fix squished thumbnails in extras tab\r\n * keep \"search\" filter for extra networks when user refreshes the tab (previously it showed everything after you refreshed)\r\n * fix webui showing the same image if you configure the generation to always save results into same file\r\n * fix bug with upscalers not working properly\r\n * fix MPS on PyTorch 2.0.1, Intel Macs\r\n * make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events\r\n * prevent Reload UI button/link from reloading the page when it's not yet ready\r\n * fix prompts from file script failing to read contents from a drag/drop file\r\n\r\n\r\n## 1.1.1\r\n### Bug Fixes:\r\n * fix an error that prevents running webui on PyTorch<2.0 without --disable-safe-unpickle\r\n\r\n## 1.1.0\r\n### Features:\r\n * switch to PyTorch 2.0.0 (except for AMD GPUs)\r\n * visual improvements to custom code scripts\r\n * add filename patterns: `[clip_skip]`, `[hasprompt<>]`, `[batch_number]`, `[generation_number]`\r\n * add support for saving init images in img2img, and record their hashes in infotext for reproducibility\r\n * automatically select current word when adjusting weight with ctrl+up/down\r\n * add dropdowns for X/Y/Z plot\r\n * add setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs\r\n * support Gradio's theme API\r\n * use TCMalloc on Linux by default; possible fix for memory leaks\r\n * add optimization 
option to remove negative conditioning at low sigma values #9177\r\n * embed model merge metadata in .safetensors file\r\n * extension settings backup/restore feature #9169\r\n * add \"resize by\" and \"resize to\" tabs to img2img\r\n * add option \"keep original size\" to textual inversion images preprocess\r\n * image viewer scrolling via analog stick\r\n * button to restore the progress from session lost / tab reload\r\n\r\n### Minor:\r\n * bump Gradio to 3.28.1\r\n * change \"scale to\" to sliders in Extras tab\r\n * add labels to tool buttons to make it possible to hide them\r\n * add tiled inference support for ScuNET\r\n * add branch support for extension installation\r\n * change Linux installation script to install into current directory rather than `/home/username`\r\n * sort textual inversion embeddings by name (case-insensitive)\r\n * allow styles.csv to be symlinked or mounted in docker\r\n * remove the \"do not add watermark to images\" option\r\n * make selected tab configurable with UI config\r\n * make the extra networks UI fixed height and scrollable\r\n * add `disable_tls_verify` arg for use with self-signed certs\r\n\r\n### Extensions:\r\n * add reload callback\r\n * add `is_hr_pass` field for processing\r\n\r\n### Bug Fixes:\r\n * fix broken batch image processing on 'Extras/Batch Process' tab\r\n * add \"None\" option to extra networks dropdowns\r\n * fix FileExistsError for CLIP Interrogator\r\n * fix /sdapi/v1/txt2img endpoint not working on Linux #9319\r\n * fix disappearing live previews and progressbar during slow tasks\r\n * fix fullscreen image view not working properly in some cases\r\n * prevent alwayson_scripts args param resizing script_arg list when they are inserted in it\r\n * fix prompt schedule for second order samplers\r\n * fix image mask/composite for weird resolutions #9628\r\n * use correct images for previews when using AND (see #9491)\r\n * one broken image in img2img batch won't stop all processing\r\n * fix image 
orientation bug in train/preprocess\r\n * fix Ngrok recreating tunnels every reload\r\n * fix `--realesrgan-models-path` and `--ldsr-models-path` not working\r\n * fix `--skip-install` not working\r\n * use SAMPLE file format in Outpainting Mk2 & Poorman\r\n * do not fail all LoRAs if some have failed to load when making a picture\r\n\r\n## 1.0.0\r\n  * everything\r\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n  - given-names: AUTOMATIC1111\ntitle: \"Stable Diffusion Web UI\"\ndate-released: 2022-08-22\nurl: \"https://github.com/AUTOMATIC1111/stable-diffusion-webui\"\n"
  },
  {
    "path": "CODEOWNERS",
    "content": "*       @AUTOMATIC1111\r\n\r\n# if you were managing a localization and were removed from this file, this is because\r\n# the intended way to do localizations now is via extensions. See:\r\n# https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions\r\n# Make a repo with your localization and since you are still listed as a collaborator\r\n# you can add it to the wiki page yourself. This change is because some people complained\r\n# the git commit log is cluttered with things unrelated to almost everyone and\r\n# because I believe this is the best overall for the project to handle localizations almost\r\n# entirely without my oversight.\r\n\r\n\r\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\r\n                       Version 3, 19 November 2007\r\n\r\n                    Copyright (c) 2023 AUTOMATIC1111\r\n\r\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\r\n Everyone is permitted to copy and distribute verbatim copies\r\n of this license document, but changing it is not allowed.\r\n\r\n                            Preamble\r\n\r\n  The GNU Affero General Public License is a free, copyleft license for\r\nsoftware and other kinds of works, specifically designed to ensure\r\ncooperation with the community in the case of network server software.\r\n\r\n  The licenses for most software and other practical works are designed\r\nto take away your freedom to share and change the works.  By contrast,\r\nour General Public Licenses are intended to guarantee your freedom to\r\nshare and change all versions of a program--to make sure it remains free\r\nsoftware for all its users.\r\n\r\n  When we speak of free software, we are referring to freedom, not\r\nprice.  Our General Public Licenses are designed to make sure that you\r\nhave the freedom to distribute copies of free software (and charge for\r\nthem if you wish), that you receive source code or can get it if you\r\nwant it, that you can change the software or use pieces of it in new\r\nfree programs, and that you know you can do these things.\r\n\r\n  Developers that use our General Public Licenses protect your rights\r\nwith two steps: (1) assert copyright on the software, and (2) offer\r\nyou this License which gives you legal permission to copy, distribute\r\nand/or modify the software.\r\n\r\n  A secondary benefit of defending all users' freedom is that\r\nimprovements made in alternate versions of the program, if they\r\nreceive widespread use, become available for other developers to\r\nincorporate.  Many developers of free software are heartened and\r\nencouraged by the resulting cooperation.  
However, in the case of\r\nsoftware used on network servers, this result may fail to come about.\r\nThe GNU General Public License permits making a modified version and\r\nletting the public access it on a server without ever releasing its\r\nsource code to the public.\r\n\r\n  The GNU Affero General Public License is designed specifically to\r\nensure that, in such cases, the modified source code becomes available\r\nto the community.  It requires the operator of a network server to\r\nprovide the source code of the modified version running there to the\r\nusers of that server.  Therefore, public use of a modified version, on\r\na publicly accessible server, gives the public access to the source\r\ncode of the modified version.\r\n\r\n  An older license, called the Affero General Public License and\r\npublished by Affero, was designed to accomplish similar goals.  This is\r\na different license, not a version of the Affero GPL, but Affero has\r\nreleased a new version of the Affero GPL which permits relicensing under\r\nthis license.\r\n\r\n  The precise terms and conditions for copying, distribution and\r\nmodification follow.\r\n\r\n                       TERMS AND CONDITIONS\r\n\r\n  0. Definitions.\r\n\r\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\r\n\r\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\r\nworks, such as semiconductor masks.\r\n\r\n  \"The Program\" refers to any copyrightable work licensed under this\r\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\r\n\"recipients\" may be individuals or organizations.\r\n\r\n  To \"modify\" a work means to copy from or adapt all or part of the work\r\nin a fashion requiring copyright permission, other than the making of an\r\nexact copy.  
The resulting work is called a \"modified version\" of the\r\nearlier work or a work \"based on\" the earlier work.\r\n\r\n  A \"covered work\" means either the unmodified Program or a work based\r\non the Program.\r\n\r\n  To \"propagate\" a work means to do anything with it that, without\r\npermission, would make you directly or secondarily liable for\r\ninfringement under applicable copyright law, except executing it on a\r\ncomputer or modifying a private copy.  Propagation includes copying,\r\ndistribution (with or without modification), making available to the\r\npublic, and in some countries other activities as well.\r\n\r\n  To \"convey\" a work means any kind of propagation that enables other\r\nparties to make or receive copies.  Mere interaction with a user through\r\na computer network, with no transfer of a copy, is not conveying.\r\n\r\n  An interactive user interface displays \"Appropriate Legal Notices\"\r\nto the extent that it includes a convenient and prominently visible\r\nfeature that (1) displays an appropriate copyright notice, and (2)\r\ntells the user that there is no warranty for the work (except to the\r\nextent that warranties are provided), that licensees may convey the\r\nwork under this License, and how to view a copy of this License.  If\r\nthe interface presents a list of user commands or options, such as a\r\nmenu, a prominent item in the list meets this criterion.\r\n\r\n  1. Source Code.\r\n\r\n  The \"source code\" for a work means the preferred form of the work\r\nfor making modifications to it.  
\"Object code\" means any non-source\r\nform of a work.\r\n\r\n  A \"Standard Interface\" means an interface that either is an official\r\nstandard defined by a recognized standards body, or, in the case of\r\ninterfaces specified for a particular programming language, one that\r\nis widely used among developers working in that language.\r\n\r\n  The \"System Libraries\" of an executable work include anything, other\r\nthan the work as a whole, that (a) is included in the normal form of\r\npackaging a Major Component, but which is not part of that Major\r\nComponent, and (b) serves only to enable use of the work with that\r\nMajor Component, or to implement a Standard Interface for which an\r\nimplementation is available to the public in source code form.  A\r\n\"Major Component\", in this context, means a major essential component\r\n(kernel, window system, and so on) of the specific operating system\r\n(if any) on which the executable work runs, or a compiler used to\r\nproduce the work, or an object code interpreter used to run it.\r\n\r\n  The \"Corresponding Source\" for a work in object code form means all\r\nthe source code needed to generate, install, and (for an executable\r\nwork) run the object code and to modify the work, including scripts to\r\ncontrol those activities.  However, it does not include the work's\r\nSystem Libraries, or general-purpose tools or generally available free\r\nprograms which are used unmodified in performing those activities but\r\nwhich are not part of the work.  
For example, Corresponding Source\r\nincludes interface definition files associated with source files for\r\nthe work, and the source code for shared libraries and dynamically\r\nlinked subprograms that the work is specifically designed to require,\r\nsuch as by intimate data communication or control flow between those\r\nsubprograms and other parts of the work.\r\n\r\n  The Corresponding Source need not include anything that users\r\ncan regenerate automatically from other parts of the Corresponding\r\nSource.\r\n\r\n  The Corresponding Source for a work in source code form is that\r\nsame work.\r\n\r\n  2. Basic Permissions.\r\n\r\n  All rights granted under this License are granted for the term of\r\ncopyright on the Program, and are irrevocable provided the stated\r\nconditions are met.  This License explicitly affirms your unlimited\r\npermission to run the unmodified Program.  The output from running a\r\ncovered work is covered by this License only if the output, given its\r\ncontent, constitutes a covered work.  This License acknowledges your\r\nrights of fair use or other equivalent, as provided by copyright law.\r\n\r\n  You may make, run and propagate covered works that you do not\r\nconvey, without conditions so long as your license otherwise remains\r\nin force.  You may convey covered works to others for the sole purpose\r\nof having them make modifications exclusively for you, or provide you\r\nwith facilities for running those works, provided that you comply with\r\nthe terms of this License in conveying all material for which you do\r\nnot control copyright.  Those thus making or running the covered works\r\nfor you must do so exclusively on your behalf, under your direction\r\nand control, on terms that prohibit them from making any copies of\r\nyour copyrighted material outside their relationship with you.\r\n\r\n  Conveying under any other circumstances is permitted solely under\r\nthe conditions stated below.  
Sublicensing is not allowed; section 10\r\nmakes it unnecessary.\r\n\r\n  3. Protecting Users' Legal Rights From Anti-Circumvention Law.\r\n\r\n  No covered work shall be deemed part of an effective technological\r\nmeasure under any applicable law fulfilling obligations under article\r\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\r\nsimilar laws prohibiting or restricting circumvention of such\r\nmeasures.\r\n\r\n  When you convey a covered work, you waive any legal power to forbid\r\ncircumvention of technological measures to the extent such circumvention\r\nis effected by exercising rights under this License with respect to\r\nthe covered work, and you disclaim any intention to limit operation or\r\nmodification of the work as a means of enforcing, against the work's\r\nusers, your or third parties' legal rights to forbid circumvention of\r\ntechnological measures.\r\n\r\n  4. Conveying Verbatim Copies.\r\n\r\n  You may convey verbatim copies of the Program's source code as you\r\nreceive it, in any medium, provided that you conspicuously and\r\nappropriately publish on each copy an appropriate copyright notice;\r\nkeep intact all notices stating that this License and any\r\nnon-permissive terms added in accord with section 7 apply to the code;\r\nkeep intact all notices of the absence of any warranty; and give all\r\nrecipients a copy of this License along with the Program.\r\n\r\n  You may charge any price or no price for each copy that you convey,\r\nand you may offer support or warranty protection for a fee.\r\n\r\n  5. 
Conveying Modified Source Versions.\r\n\r\n  You may convey a work based on the Program, or the modifications to\r\nproduce it from the Program, in the form of source code under the\r\nterms of section 4, provided that you also meet all of these conditions:\r\n\r\n    a) The work must carry prominent notices stating that you modified\r\n    it, and giving a relevant date.\r\n\r\n    b) The work must carry prominent notices stating that it is\r\n    released under this License and any conditions added under section\r\n    7.  This requirement modifies the requirement in section 4 to\r\n    \"keep intact all notices\".\r\n\r\n    c) You must license the entire work, as a whole, under this\r\n    License to anyone who comes into possession of a copy.  This\r\n    License will therefore apply, along with any applicable section 7\r\n    additional terms, to the whole of the work, and all its parts,\r\n    regardless of how they are packaged.  This License gives no\r\n    permission to license the work in any other way, but it does not\r\n    invalidate such permission if you have separately received it.\r\n\r\n    d) If the work has interactive user interfaces, each must display\r\n    Appropriate Legal Notices; however, if the Program has interactive\r\n    interfaces that do not display Appropriate Legal Notices, your\r\n    work need not make them do so.\r\n\r\n  A compilation of a covered work with other separate and independent\r\nworks, which are not by their nature extensions of the covered work,\r\nand which are not combined with it such as to form a larger program,\r\nin or on a volume of a storage or distribution medium, is called an\r\n\"aggregate\" if the compilation and its resulting copyright are not\r\nused to limit the access or legal rights of the compilation's users\r\nbeyond what the individual works permit.  Inclusion of a covered work\r\nin an aggregate does not cause this License to apply to the other\r\nparts of the aggregate.\r\n\r\n  6. 
Conveying Non-Source Forms.\r\n\r\n  You may convey a covered work in object code form under the terms\r\nof sections 4 and 5, provided that you also convey the\r\nmachine-readable Corresponding Source under the terms of this License,\r\nin one of these ways:\r\n\r\n    a) Convey the object code in, or embodied in, a physical product\r\n    (including a physical distribution medium), accompanied by the\r\n    Corresponding Source fixed on a durable physical medium\r\n    customarily used for software interchange.\r\n\r\n    b) Convey the object code in, or embodied in, a physical product\r\n    (including a physical distribution medium), accompanied by a\r\n    written offer, valid for at least three years and valid for as\r\n    long as you offer spare parts or customer support for that product\r\n    model, to give anyone who possesses the object code either (1) a\r\n    copy of the Corresponding Source for all the software in the\r\n    product that is covered by this License, on a durable physical\r\n    medium customarily used for software interchange, for a price no\r\n    more than your reasonable cost of physically performing this\r\n    conveying of source, or (2) access to copy the\r\n    Corresponding Source from a network server at no charge.\r\n\r\n    c) Convey individual copies of the object code with a copy of the\r\n    written offer to provide the Corresponding Source.  This\r\n    alternative is allowed only occasionally and noncommercially, and\r\n    only if you received the object code with such an offer, in accord\r\n    with subsection 6b.\r\n\r\n    d) Convey the object code by offering access from a designated\r\n    place (gratis or for a charge), and offer equivalent access to the\r\n    Corresponding Source in the same way through the same place at no\r\n    further charge.  You need not require recipients to copy the\r\n    Corresponding Source along with the object code.  
If the place to\r\n    copy the object code is a network server, the Corresponding Source\r\n    may be on a different server (operated by you or a third party)\r\n    that supports equivalent copying facilities, provided you maintain\r\n    clear directions next to the object code saying where to find the\r\n    Corresponding Source.  Regardless of what server hosts the\r\n    Corresponding Source, you remain obligated to ensure that it is\r\n    available for as long as needed to satisfy these requirements.\r\n\r\n    e) Convey the object code using peer-to-peer transmission, provided\r\n    you inform other peers where the object code and Corresponding\r\n    Source of the work are being offered to the general public at no\r\n    charge under subsection 6d.\r\n\r\n  A separable portion of the object code, whose source code is excluded\r\nfrom the Corresponding Source as a System Library, need not be\r\nincluded in conveying the object code work.\r\n\r\n  A \"User Product\" is either (1) a \"consumer product\", which means any\r\ntangible personal property which is normally used for personal, family,\r\nor household purposes, or (2) anything designed or sold for incorporation\r\ninto a dwelling.  In determining whether a product is a consumer product,\r\ndoubtful cases shall be resolved in favor of coverage.  For a particular\r\nproduct received by a particular user, \"normally used\" refers to a\r\ntypical or common use of that class of product, regardless of the status\r\nof the particular user or of the way in which the particular user\r\nactually uses, or expects or is expected to use, the product.  
A product\r\nis a consumer product regardless of whether the product has substantial\r\ncommercial, industrial or non-consumer uses, unless such uses represent\r\nthe only significant mode of use of the product.\r\n\r\n  \"Installation Information\" for a User Product means any methods,\r\nprocedures, authorization keys, or other information required to install\r\nand execute modified versions of a covered work in that User Product from\r\na modified version of its Corresponding Source.  The information must\r\nsuffice to ensure that the continued functioning of the modified object\r\ncode is in no case prevented or interfered with solely because\r\nmodification has been made.\r\n\r\n  If you convey an object code work under this section in, or with, or\r\nspecifically for use in, a User Product, and the conveying occurs as\r\npart of a transaction in which the right of possession and use of the\r\nUser Product is transferred to the recipient in perpetuity or for a\r\nfixed term (regardless of how the transaction is characterized), the\r\nCorresponding Source conveyed under this section must be accompanied\r\nby the Installation Information.  But this requirement does not apply\r\nif neither you nor any third party retains the ability to install\r\nmodified object code on the User Product (for example, the work has\r\nbeen installed in ROM).\r\n\r\n  The requirement to provide Installation Information does not include a\r\nrequirement to continue to provide support service, warranty, or updates\r\nfor a work that has been modified or installed by the recipient, or for\r\nthe User Product in which it has been modified or installed.  
Access to a\r\nnetwork may be denied when the modification itself materially and\r\nadversely affects the operation of the network or violates the rules and\r\nprotocols for communication across the network.\r\n\r\n  Corresponding Source conveyed, and Installation Information provided,\r\nin accord with this section must be in a format that is publicly\r\ndocumented (and with an implementation available to the public in\r\nsource code form), and must require no special password or key for\r\nunpacking, reading or copying.\r\n\r\n  7. Additional Terms.\r\n\r\n  \"Additional permissions\" are terms that supplement the terms of this\r\nLicense by making exceptions from one or more of its conditions.\r\nAdditional permissions that are applicable to the entire Program shall\r\nbe treated as though they were included in this License, to the extent\r\nthat they are valid under applicable law.  If additional permissions\r\napply only to part of the Program, that part may be used separately\r\nunder those permissions, but the entire Program remains governed by\r\nthis License without regard to the additional permissions.\r\n\r\n  When you convey a copy of a covered work, you may at your option\r\nremove any additional permissions from that copy, or from any part of\r\nit.  (Additional permissions may be written to require their own\r\nremoval in certain cases when you modify the work.)  
You may place\r\nadditional permissions on material, added by you to a covered work,\r\nfor which you have or can give appropriate copyright permission.\r\n\r\n  Notwithstanding any other provision of this License, for material you\r\nadd to a covered work, you may (if authorized by the copyright holders of\r\nthat material) supplement the terms of this License with terms:\r\n\r\n    a) Disclaiming warranty or limiting liability differently from the\r\n    terms of sections 15 and 16 of this License; or\r\n\r\n    b) Requiring preservation of specified reasonable legal notices or\r\n    author attributions in that material or in the Appropriate Legal\r\n    Notices displayed by works containing it; or\r\n\r\n    c) Prohibiting misrepresentation of the origin of that material, or\r\n    requiring that modified versions of such material be marked in\r\n    reasonable ways as different from the original version; or\r\n\r\n    d) Limiting the use for publicity purposes of names of licensors or\r\n    authors of the material; or\r\n\r\n    e) Declining to grant rights under trademark law for use of some\r\n    trade names, trademarks, or service marks; or\r\n\r\n    f) Requiring indemnification of licensors and authors of that\r\n    material by anyone who conveys the material (or modified versions of\r\n    it) with contractual assumptions of liability to the recipient, for\r\n    any liability that these contractual assumptions directly impose on\r\n    those licensors and authors.\r\n\r\n  All other non-permissive additional terms are considered \"further\r\nrestrictions\" within the meaning of section 10.  If the Program as you\r\nreceived it, or any part of it, contains a notice stating that it is\r\ngoverned by this License along with a term that is a further\r\nrestriction, you may remove that term.  
If a license document contains\r\na further restriction but permits relicensing or conveying under this\r\nLicense, you may add to a covered work material governed by the terms\r\nof that license document, provided that the further restriction does\r\nnot survive such relicensing or conveying.\r\n\r\n  If you add terms to a covered work in accord with this section, you\r\nmust place, in the relevant source files, a statement of the\r\nadditional terms that apply to those files, or a notice indicating\r\nwhere to find the applicable terms.\r\n\r\n  Additional terms, permissive or non-permissive, may be stated in the\r\nform of a separately written license, or stated as exceptions;\r\nthe above requirements apply either way.\r\n\r\n  8. Termination.\r\n\r\n  You may not propagate or modify a covered work except as expressly\r\nprovided under this License.  Any attempt otherwise to propagate or\r\nmodify it is void, and will automatically terminate your rights under\r\nthis License (including any patent licenses granted under the third\r\nparagraph of section 11).\r\n\r\n  However, if you cease all violation of this License, then your\r\nlicense from a particular copyright holder is reinstated (a)\r\nprovisionally, unless and until the copyright holder explicitly and\r\nfinally terminates your license, and (b) permanently, if the copyright\r\nholder fails to notify you of the violation by some reasonable means\r\nprior to 60 days after the cessation.\r\n\r\n  Moreover, your license from a particular copyright holder is\r\nreinstated permanently if the copyright holder notifies you of the\r\nviolation by some reasonable means, this is the first time you have\r\nreceived notice of violation of this License (for any work) from that\r\ncopyright holder, and you cure the violation prior to 30 days after\r\nyour receipt of the notice.\r\n\r\n  Termination of your rights under this section does not terminate the\r\nlicenses of parties who have received copies or rights from 
you under\r\nthis License.  If your rights have been terminated and not permanently\r\nreinstated, you do not qualify to receive new licenses for the same\r\nmaterial under section 10.\r\n\r\n  9. Acceptance Not Required for Having Copies.\r\n\r\n  You are not required to accept this License in order to receive or\r\nrun a copy of the Program.  Ancillary propagation of a covered work\r\noccurring solely as a consequence of using peer-to-peer transmission\r\nto receive a copy likewise does not require acceptance.  However,\r\nnothing other than this License grants you permission to propagate or\r\nmodify any covered work.  These actions infringe copyright if you do\r\nnot accept this License.  Therefore, by modifying or propagating a\r\ncovered work, you indicate your acceptance of this License to do so.\r\n\r\n  10. Automatic Licensing of Downstream Recipients.\r\n\r\n  Each time you convey a covered work, the recipient automatically\r\nreceives a license from the original licensors, to run, modify and\r\npropagate that work, subject to this License.  You are not responsible\r\nfor enforcing compliance by third parties with this License.\r\n\r\n  An \"entity transaction\" is a transaction transferring control of an\r\norganization, or substantially all assets of one, or subdividing an\r\norganization, or merging organizations.  If propagation of a covered\r\nwork results from an entity transaction, each party to that\r\ntransaction who receives a copy of the work also receives whatever\r\nlicenses to the work the party's predecessor in interest had or could\r\ngive under the previous paragraph, plus a right to possession of the\r\nCorresponding Source of the work from the predecessor in interest, if\r\nthe predecessor has it or can get it with reasonable efforts.\r\n\r\n  You may not impose any further restrictions on the exercise of the\r\nrights granted or affirmed under this License.  
For example, you may\r\nnot impose a license fee, royalty, or other charge for exercise of\r\nrights granted under this License, and you may not initiate litigation\r\n(including a cross-claim or counterclaim in a lawsuit) alleging that\r\nany patent claim is infringed by making, using, selling, offering for\r\nsale, or importing the Program or any portion of it.\r\n\r\n  11. Patents.\r\n\r\n  A \"contributor\" is a copyright holder who authorizes use under this\r\nLicense of the Program or a work on which the Program is based.  The\r\nwork thus licensed is called the contributor's \"contributor version\".\r\n\r\n  A contributor's \"essential patent claims\" are all patent claims\r\nowned or controlled by the contributor, whether already acquired or\r\nhereafter acquired, that would be infringed by some manner, permitted\r\nby this License, of making, using, or selling its contributor version,\r\nbut do not include claims that would be infringed only as a\r\nconsequence of further modification of the contributor version.  For\r\npurposes of this definition, \"control\" includes the right to grant\r\npatent sublicenses in a manner consistent with the requirements of\r\nthis License.\r\n\r\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\r\npatent license under the contributor's essential patent claims, to\r\nmake, use, sell, offer for sale, import and otherwise run, modify and\r\npropagate the contents of its contributor version.\r\n\r\n  In the following three paragraphs, a \"patent license\" is any express\r\nagreement or commitment, however denominated, not to enforce a patent\r\n(such as an express permission to practice a patent or covenant not to\r\nsue for patent infringement).  
To \"grant\" such a patent license to a\r\nparty means to make such an agreement or commitment not to enforce a\r\npatent against the party.\r\n\r\n  If you convey a covered work, knowingly relying on a patent license,\r\nand the Corresponding Source of the work is not available for anyone\r\nto copy, free of charge and under the terms of this License, through a\r\npublicly available network server or other readily accessible means,\r\nthen you must either (1) cause the Corresponding Source to be so\r\navailable, or (2) arrange to deprive yourself of the benefit of the\r\npatent license for this particular work, or (3) arrange, in a manner\r\nconsistent with the requirements of this License, to extend the patent\r\nlicense to downstream recipients.  \"Knowingly relying\" means you have\r\nactual knowledge that, but for the patent license, your conveying the\r\ncovered work in a country, or your recipient's use of the covered work\r\nin a country, would infringe one or more identifiable patents in that\r\ncountry that you have reason to believe are valid.\r\n\r\n  If, pursuant to or in connection with a single transaction or\r\narrangement, you convey, or propagate by procuring conveyance of, a\r\ncovered work, and grant a patent license to some of the parties\r\nreceiving the covered work authorizing them to use, propagate, modify\r\nor convey a specific copy of the covered work, then the patent license\r\nyou grant is automatically extended to all recipients of the covered\r\nwork and works based on it.\r\n\r\n  A patent license is \"discriminatory\" if it does not include within\r\nthe scope of its coverage, prohibits the exercise of, or is\r\nconditioned on the non-exercise of one or more of the rights that are\r\nspecifically granted under this License.  
You may not convey a covered\r\nwork if you are a party to an arrangement with a third party that is\r\nin the business of distributing software, under which you make payment\r\nto the third party based on the extent of your activity of conveying\r\nthe work, and under which the third party grants, to any of the\r\nparties who would receive the covered work from you, a discriminatory\r\npatent license (a) in connection with copies of the covered work\r\nconveyed by you (or copies made from those copies), or (b) primarily\r\nfor and in connection with specific products or compilations that\r\ncontain the covered work, unless you entered into that arrangement,\r\nor that patent license was granted, prior to 28 March 2007.\r\n\r\n  Nothing in this License shall be construed as excluding or limiting\r\nany implied license or other defenses to infringement that may\r\notherwise be available to you under applicable patent law.\r\n\r\n  12. No Surrender of Others' Freedom.\r\n\r\n  If conditions are imposed on you (whether by court order, agreement or\r\notherwise) that contradict the conditions of this License, they do not\r\nexcuse you from the conditions of this License.  If you cannot convey a\r\ncovered work so as to satisfy simultaneously your obligations under this\r\nLicense and any other pertinent obligations, then as a consequence you may\r\nnot convey it at all.  For example, if you agree to terms that obligate you\r\nto collect a royalty for further conveying from those to whom you convey\r\nthe Program, the only way you could satisfy both those terms and this\r\nLicense would be to refrain entirely from conveying the Program.\r\n\r\n  13. 
Remote Network Interaction; Use with the GNU General Public License.\r\n\r\n  Notwithstanding any other provision of this License, if you modify the\r\nProgram, your modified version must prominently offer all users\r\ninteracting with it remotely through a computer network (if your version\r\nsupports such interaction) an opportunity to receive the Corresponding\r\nSource of your version by providing access to the Corresponding Source\r\nfrom a network server at no charge, through some standard or customary\r\nmeans of facilitating copying of software.  This Corresponding Source\r\nshall include the Corresponding Source for any work covered by version 3\r\nof the GNU General Public License that is incorporated pursuant to the\r\nfollowing paragraph.\r\n\r\n  Notwithstanding any other provision of this License, you have\r\npermission to link or combine any covered work with a work licensed\r\nunder version 3 of the GNU General Public License into a single\r\ncombined work, and to convey the resulting work.  The terms of this\r\nLicense will continue to apply to the part which is the covered work,\r\nbut the work with which it is combined will remain governed by version\r\n3 of the GNU General Public License.\r\n\r\n  14. Revised Versions of this License.\r\n\r\n  The Free Software Foundation may publish revised and/or new versions of\r\nthe GNU Affero General Public License from time to time.  Such new versions\r\nwill be similar in spirit to the present version, but may differ in detail to\r\naddress new problems or concerns.\r\n\r\n  Each version is given a distinguishing version number.  If the\r\nProgram specifies that a certain numbered version of the GNU Affero General\r\nPublic License \"or any later version\" applies to it, you have the\r\noption of following the terms and conditions either of that numbered\r\nversion or of any later version published by the Free Software\r\nFoundation.  
If the Program does not specify a version number of the\r\nGNU Affero General Public License, you may choose any version ever published\r\nby the Free Software Foundation.\r\n\r\n  If the Program specifies that a proxy can decide which future\r\nversions of the GNU Affero General Public License can be used, that proxy's\r\npublic statement of acceptance of a version permanently authorizes you\r\nto choose that version for the Program.\r\n\r\n  Later license versions may give you additional or different\r\npermissions.  However, no additional obligations are imposed on any\r\nauthor or copyright holder as a result of your choosing to follow a\r\nlater version.\r\n\r\n  15. Disclaimer of Warranty.\r\n\r\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\r\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\r\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\r\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\r\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\r\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\r\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\r\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\r\n\r\n  16. Limitation of Liability.\r\n\r\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\r\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\r\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\r\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\r\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\r\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\r\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\r\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\r\nSUCH DAMAGES.\r\n\r\n  17. 
Interpretation of Sections 15 and 16.\r\n\r\n  If the disclaimer of warranty and limitation of liability provided\r\nabove cannot be given local legal effect according to their terms,\r\nreviewing courts shall apply local law that most closely approximates\r\nan absolute waiver of all civil liability in connection with the\r\nProgram, unless a warranty or assumption of liability accompanies a\r\ncopy of the Program in return for a fee.\r\n\r\n                     END OF TERMS AND CONDITIONS\r\n\r\n            How to Apply These Terms to Your New Programs\r\n\r\n  If you develop a new program, and you want it to be of the greatest\r\npossible use to the public, the best way to achieve this is to make it\r\nfree software which everyone can redistribute and change under these terms.\r\n\r\n  To do so, attach the following notices to the program.  It is safest\r\nto attach them to the start of each source file to most effectively\r\nstate the exclusion of warranty; and each file should have at least\r\nthe \"copyright\" line and a pointer to where the full notice is found.\r\n\r\n    <one line to give the program's name and a brief idea of what it does.>\r\n    Copyright (C) <year>  <name of author>\r\n\r\n    This program is free software: you can redistribute it and/or modify\r\n    it under the terms of the GNU Affero General Public License as published by\r\n    the Free Software Foundation, either version 3 of the License, or\r\n    (at your option) any later version.\r\n\r\n    This program is distributed in the hope that it will be useful,\r\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\r\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\r\n    GNU Affero General Public License for more details.\r\n\r\n    You should have received a copy of the GNU Affero General Public License\r\n    along with this program.  
If not, see <https://www.gnu.org/licenses/>.\r\n\r\nAlso add information on how to contact you by electronic and paper mail.\r\n\r\n  If your software can interact with users remotely through a computer\r\nnetwork, you should also make sure that it provides a way for users to\r\nget its source.  For example, if your program is a web application, its\r\ninterface could display a \"Source\" link that leads users to an archive\r\nof the code.  There are many ways you could offer source, and different\r\nsolutions will be better for different programs; see section 13 for the\r\nspecific requirements.\r\n\r\n  You should also get your employer (if you work as a programmer) or school,\r\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\r\nFor more information on this, and how to apply and follow the GNU AGPL, see\r\n<https://www.gnu.org/licenses/>.\r\n"
  },
  {
    "path": "README.md",
    "content": "# Stable Diffusion web UI\r\nA web interface for Stable Diffusion, implemented using Gradio library.\r\n\r\n![](screenshot.png)\r\n\r\n## Features\r\n[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):\r\n- Original txt2img and img2img modes\r\n- One click install and run script (but you still must install python and git)\r\n- Outpainting\r\n- Inpainting\r\n- Color Sketch\r\n- Prompt Matrix\r\n- Stable Diffusion Upscale\r\n- Attention, specify parts of text that the model should pay more attention to\r\n    - a man in a `((tuxedo))` - will pay more attention to tuxedo\r\n    - a man in a `(tuxedo:1.21)` - alternative syntax\r\n    - select text and press `Ctrl+Up` or `Ctrl+Down` (or `Command+Up` or `Command+Down` if you're on a MacOS) to automatically adjust attention to selected text (code contributed by anonymous user)\r\n- Loopback, run img2img processing multiple times\r\n- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters\r\n- Textual Inversion\r\n    - have as many embeddings as you want and use any names you like for them\r\n    - use multiple embeddings with different numbers of vectors per token\r\n    - works with half precision floating point numbers\r\n    - train embeddings on 8GB (also reports of 6GB working)\r\n- Extras tab with:\r\n    - GFPGAN, neural network that fixes faces\r\n    - CodeFormer, face restoration tool as an alternative to GFPGAN\r\n    - RealESRGAN, neural network upscaler\r\n    - ESRGAN, neural network upscaler with a lot of third party models\r\n    - SwinIR and Swin2SR ([see here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2092)), neural network upscalers\r\n    - LDSR, Latent diffusion super resolution upscaling\r\n- Resizing aspect ratio options\r\n- Sampling method selection\r\n    - Adjust sampler eta values (noise multiplier)\r\n    - More advanced noise setting options\r\n- Interrupt processing at 
any time\r\n- 4GB video card support (also reports of 2GB working)\r\n- Correct seeds for batches\r\n- Live prompt token length validation\r\n- Generation parameters\r\n     - parameters you used to generate images are saved with that image\r\n     - in PNG chunks for PNG, in EXIF for JPEG\r\n     - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI\r\n     - can be disabled in settings\r\n     - drag and drop an image/text-parameters to promptbox\r\n- Read Generation Parameters Button, loads parameters in promptbox to UI\r\n- Settings page\r\n- Running arbitrary python code from UI (must run with `--allow-code` to enable)\r\n- Mouseover hints for most UI elements\r\n- Possible to change defaults/min/max/step values for UI elements via text config\r\n- Tiling support, a checkbox to create images that can be tiled like textures\r\n- Progress bar and live image generation preview\r\n    - Can use a separate neural network to produce previews with almost no VRAM or compute requirements\r\n- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image\r\n- Styles, a way to save part of prompt and easily apply them via dropdown later\r\n- Variations, a way to generate same image but with tiny differences\r\n- Seed resizing, a way to generate same image but at slightly different resolution\r\n- CLIP interrogator, a button that tries to guess prompt from an image\r\n- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway\r\n- Batch Processing, process a group of files using img2img\r\n- Img2img Alternative, reverse Euler method of cross attention control\r\n- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions\r\n- Reloading checkpoints on the fly\r\n- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one\r\n- [Custom 
scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from the community\r\n- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once\r\n     - separate prompts using uppercase `AND`\r\n     - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`\r\n- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)\r\n- DeepDanbooru integration, creates danbooru style tags for anime prompts\r\n- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add `--xformers` to commandline args)\r\n- via extension: [History tab](https://github.com/yfszzx/stable-diffusion-webui-images-browser): view, direct and delete images conveniently within the UI\r\n- Generate forever option\r\n- Training tab\r\n     - hypernetworks and embeddings options\r\n     - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)\r\n- Clip skip\r\n- Hypernetworks\r\n- Loras (same as Hypernetworks but prettier)\r\n- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt\r\n- Can select to load a different VAE from settings screen\r\n- Estimated completion time in progress bar\r\n- API\r\n- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML\r\n- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using CLIP image embeddings (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))\r\n- [Stable Diffusion 
2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions\r\n- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions\r\n- Now without any bad letters!\r\n- Load checkpoints in safetensors format\r\n- Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64\r\n- Now with a license!\r\n- Reorder elements in the UI from settings screen\r\n- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support\r\n\r\n## Installation and Running\r\nMake sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:\r\n- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended)\r\n- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.\r\n- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page)\r\n- [Ascend NPUs](https://github.com/wangshuai09/stable-diffusion-webui/wiki/Install-and-run-on-Ascend-NPUs) (external wiki page)\r\n\r\nAlternatively, use online services (like Google Colab):\r\n\r\n- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)\r\n\r\n### Installation on Windows 10/11 with NVidia-GPUs using release package\r\n1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract its contents.\r\n2. Run `update.bat`.\r\n3. 
Run `run.bat`.\r\n> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)\r\n\r\n### Automatic Installation on Windows\r\n1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (newer versions of Python do not support torch), checking \"Add Python to PATH\".\r\n2. Install [git](https://git-scm.com/download/win).\r\n3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.\r\n4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.\r\n\r\n### Automatic Installation on Linux\r\n1. Install the dependencies:\r\n```bash\r\n# Debian-based:\r\nsudo apt install wget git python3 python3-venv libgl1 libglib2.0-0\r\n# Red Hat-based:\r\nsudo dnf install wget git python3 gperftools-libs libglvnd-glx\r\n# openSUSE-based:\r\nsudo zypper install wget git python3 libtcmalloc4 libglvnd\r\n# Arch-based:\r\nsudo pacman -S wget git python3\r\n```\r\nIf your system is very new, you need to install python3.11 or python3.10:\r\n```bash\r\n# Ubuntu 24.04\r\nsudo add-apt-repository ppa:deadsnakes/ppa\r\nsudo apt update\r\nsudo apt install python3.11\r\n\r\n# Manjaro/Arch\r\nsudo pacman -S yay\r\nyay -S python311 # do not confuse with python3.11 package\r\n\r\n# Only for 3.11\r\n# Then set up env variable in launch script\r\nexport python_cmd=\"python3.11\"\r\n# or in webui-user.sh\r\npython_cmd=\"python3.11\"\r\n```\r\n2. Navigate to the directory where you would like the webui to be installed and execute the following command:\r\n```bash\r\nwget -q https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh\r\n```\r\nOr just clone the repo wherever you want:\r\n```bash\r\ngit clone https://github.com/AUTOMATIC1111/stable-diffusion-webui\r\n```\r\n\r\n3. Run `webui.sh`.\r\n4. 
Check `webui-user.sh` for options.\r\n### Installation on Apple Silicon\r\n\r\nFind the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).\r\n\r\n## Contributing\r\nHere's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)\r\n\r\n## Documentation\r\n\r\nThe documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).\r\n\r\nFor the purposes of getting Google and other search engines to crawl the wiki, here's a link to the (not for humans) [crawlable wiki](https://github-wiki-see.page/m/AUTOMATIC1111/stable-diffusion-webui/wiki).\r\n\r\n## Credits\r\nLicenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file.\r\n\r\n- Stable Diffusion - https://github.com/Stability-AI/stablediffusion, https://github.com/CompVis/taming-transformers, https://github.com/mcmonkey4eva/sd3-ref\r\n- k-diffusion - https://github.com/crowsonkb/k-diffusion.git\r\n- Spandrel - https://github.com/chaiNNer-org/spandrel implementing\r\n  - GFPGAN - https://github.com/TencentARC/GFPGAN.git\r\n  - CodeFormer - https://github.com/sczhou/CodeFormer\r\n  - ESRGAN - https://github.com/xinntao/ESRGAN\r\n  - SwinIR - https://github.com/JingyunLiang/SwinIR\r\n  - Swin2SR - https://github.com/mv-lab/swin2sr\r\n- LDSR - https://github.com/Hafiidz/latent-diffusion\r\n- MiDaS - https://github.com/isl-org/MiDaS\r\n- Ideas for optimizations - https://github.com/basujindal/stable-diffusion\r\n- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.\r\n- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)\r\n- Sub-quadratic Cross Attention layer optimization - Alex Birch 
(https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)\r\n- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).\r\n- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd\r\n- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot\r\n- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator\r\n- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch\r\n- xformers - https://github.com/facebookresearch/xformers\r\n- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru\r\n- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)\r\n- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix\r\n- Security advice - RyotaK\r\n- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC\r\n- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd\r\n- LyCORIS - KohakuBlueleaf\r\n- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling\r\n- Hypertile - tfernd - https://github.com/tfernd/HyperTile\r\n- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.\r\n- (You)\r\n"
  },
  {
    "path": "_typos.toml",
    "content": "[default.extend-words]\n# Part of \"RGBa\" (Pillow's pre-multiplied alpha RGB mode)\nBa = \"Ba\"\n# HSA is something AMD uses for their GPUs\nHSA = \"HSA\"\n"
  },
  {
    "path": "configs/alt-diffusion-inference.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    scheduler_config: # 10000 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 10000 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: False\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: modules.xlmr.BertSeriesModelWithTransformation\n      params:\n        name: 
\"XLMR-Large\""
  },
  {
    "path": "configs/alt-diffusion-m18-inference.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    scheduler_config: # 10000 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 10000 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_head_channels: 64\n        use_spatial_transformer: True\n        use_linear_in_transformer: True\n        transformer_depth: 1\n        context_dim: 1024\n        use_checkpoint: False\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: 
modules.xlmr_m18.BertSeriesModelWithTransformation\n      params:\n        name: \"XLMR-Large\"\n"
  },
  {
    "path": "configs/instruct-pix2pix.yaml",
    "content": "# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).\n# See more details in LICENSE.\n\nmodel:\n  base_learning_rate: 1.0e-04\n  target: modules.models.diffusion.ddpm_edit.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: edited\n    cond_stage_key: edit\n    # image_size: 64\n    # image_size: 32\n    image_size: 16\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: hybrid\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: false\n\n    scheduler_config: # 10000 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 0 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. 
]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 8\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: False\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n\ndata:\n  target: main.DataModuleFromConfig\n  params:\n    batch_size: 128\n    num_workers: 1\n    wrap: false\n    validation:\n      target: edit_dataset.EditDataset\n      params:\n        path: data/clip-filtered-dataset\n        cache_dir:  data/\n        cache_name: data_10k\n        split: val\n        min_text_sim: 0.2\n        min_image_sim: 0.75\n        min_direction_sim: 0.2\n        max_samples_per_prompt: 1\n        min_resize_res: 512\n        max_resize_res: 512\n        crop_res: 512\n        output_as_edit: False\n        real_input: True\n"
  },
  {
    "path": "configs/sd3-inference.yaml",
    "content": "model:\n  target: modules.models.sd3.sd3_model.SD3Inferencer\n  params:\n    shift: 3\n    state_dict: null\n"
  },
  {
    "path": "configs/sd_xl_inpaint.yaml",
    "content": "model:\n  target: sgm.models.diffusion.DiffusionEngine\n  params:\n    scale_factor: 0.13025\n    disable_first_stage_autocast: True\n\n    denoiser_config:\n      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser\n      params:\n        num_idx: 1000\n\n        weighting_config:\n          target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting\n        scaling_config:\n          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling\n        discretization_config:\n          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization\n\n    network_config:\n      target: sgm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        adm_in_channels: 2816\n        num_classes: sequential\n        use_checkpoint: False\n        in_channels: 9\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [4, 2]\n        num_res_blocks: 2\n        channel_mult: [1, 2, 4]\n        num_head_channels: 64\n        use_spatial_transformer: True\n        use_linear_in_transformer: True\n        transformer_depth: [1, 2, 10]  # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16\n        context_dim: 2048\n        spatial_transformer_attn_type: softmax-xformers\n        legacy: False\n\n    conditioner_config:\n      target: sgm.modules.GeneralConditioner\n      params:\n        emb_models:\n          # crossattn cond\n          - is_trainable: False\n            input_key: txt\n            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder\n            params:\n              layer: hidden\n              layer_idx: 11\n          # crossattn and vector cond\n          - is_trainable: False\n            input_key: txt\n            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2\n            params:\n              arch: ViT-bigG-14\n              version: laion2b_s39b_b160k\n              freeze: True\n              layer: 
penultimate\n              always_return_pooled: True\n              legacy: False\n          # vector cond\n          - is_trainable: False\n            input_key: original_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256  # multiplied by two\n          # vector cond\n          - is_trainable: False\n            input_key: crop_coords_top_left\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256  # multiplied by two\n          # vector cond\n          - is_trainable: False\n            input_key: target_size_as_tuple\n            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND\n            params:\n              outdim: 256  # multiplied by two\n\n    first_stage_config:\n      target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          attn_type: vanilla-xformers\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult: [1, 2, 4, 4]\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n"
  },
  {
    "path": "configs/v1-inference.yaml",
    "content": "model:\n  base_learning_rate: 1.0e-04\n  target: ldm.models.diffusion.ddpm.LatentDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: crossattn\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    use_ema: False\n\n    scheduler_config: # 10000 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 10000 ]\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 4\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: False\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n"
  },
  {
    "path": "configs/v1-inpainting-inference.yaml",
    "content": "model:\n  base_learning_rate: 7.5e-05\n  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion\n  params:\n    linear_start: 0.00085\n    linear_end: 0.0120\n    num_timesteps_cond: 1\n    log_every_t: 200\n    timesteps: 1000\n    first_stage_key: \"jpg\"\n    cond_stage_key: \"txt\"\n    image_size: 64\n    channels: 4\n    cond_stage_trainable: false   # Note: different from the one we trained before\n    conditioning_key: hybrid   # important\n    monitor: val/loss_simple_ema\n    scale_factor: 0.18215\n    finetune_keys: null\n\n    scheduler_config: # 10000 warmup steps\n      target: ldm.lr_scheduler.LambdaLinearScheduler\n      params:\n        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch\n        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases\n        f_start: [ 1.e-6 ]\n        f_max: [ 1. ]\n        f_min: [ 1. ]\n\n    unet_config:\n      target: ldm.modules.diffusionmodules.openaimodel.UNetModel\n      params:\n        image_size: 32 # unused\n        in_channels: 9  # 4 data + 4 downscaled image + 1 mask\n        out_channels: 4\n        model_channels: 320\n        attention_resolutions: [ 4, 2, 1 ]\n        num_res_blocks: 2\n        channel_mult: [ 1, 2, 4, 4 ]\n        num_heads: 8\n        use_spatial_transformer: True\n        transformer_depth: 1\n        context_dim: 768\n        use_checkpoint: False\n        legacy: False\n\n    first_stage_config:\n      target: ldm.models.autoencoder.AutoencoderKL\n      params:\n        embed_dim: 4\n        monitor: val/rec_loss\n        ddconfig:\n          double_z: true\n          z_channels: 4\n          resolution: 256\n          in_channels: 3\n          out_ch: 3\n          ch: 128\n          ch_mult:\n          - 1\n          - 2\n          - 4\n          - 4\n          num_res_blocks: 2\n          attn_resolutions: []\n          dropout: 0.0\n        lossconfig:\n          target: torch.nn.Identity\n\n    
cond_stage_config:\n      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder\n"
  },
  {
    "path": "environment-wsl2.yaml",
    "content": "name: automatic\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - python=3.10\n  - pip=23.0\n  - cudatoolkit=11.8\n  - pytorch=2.0\n  - torchvision=0.15\n  - numpy=1.23\n"
  },
  {
    "path": "extensions-builtin/LDSR/ldsr_model_arch.py",
    "content": "import os\nimport gc\nimport time\n\nimport numpy as np\nimport torch\nimport torchvision\nfrom PIL import Image\nfrom einops import rearrange, repeat\nfrom omegaconf import OmegaConf\nimport safetensors.torch\n\nfrom ldm.models.diffusion.ddim import DDIMSampler\nfrom ldm.util import instantiate_from_config, ismap\nfrom modules import shared, sd_hijack, devices\n\ncached_ldsr_model: torch.nn.Module = None\n\n\n# Create LDSR Class\nclass LDSR:\n    def load_model_from_config(self, half_attention):\n        global cached_ldsr_model\n\n        if shared.opts.ldsr_cached and cached_ldsr_model is not None:\n            print(\"Loading model from cache\")\n            model: torch.nn.Module = cached_ldsr_model\n        else:\n            print(f\"Loading model from {self.modelPath}\")\n            _, extension = os.path.splitext(self.modelPath)\n            if extension.lower() == \".safetensors\":\n                pl_sd = safetensors.torch.load_file(self.modelPath, device=\"cpu\")\n            else:\n                pl_sd = torch.load(self.modelPath, map_location=\"cpu\")\n            sd = pl_sd[\"state_dict\"] if \"state_dict\" in pl_sd else pl_sd\n            config = OmegaConf.load(self.yamlPath)\n            config.model.target = \"ldm.models.diffusion.ddpm.LatentDiffusionV1\"\n            model: torch.nn.Module = instantiate_from_config(config.model)\n            model.load_state_dict(sd, strict=False)\n            model = model.to(shared.device)\n            if half_attention:\n                model = model.half()\n            if shared.cmd_opts.opt_channelslast:\n                model = model.to(memory_format=torch.channels_last)\n\n            sd_hijack.model_hijack.hijack(model) # apply optimization\n            model.eval()\n\n            if shared.opts.ldsr_cached:\n                cached_ldsr_model = model\n\n        return {\"model\": model}\n\n    def __init__(self, model_path, yaml_path):\n        self.modelPath = model_path\n        
self.yamlPath = yaml_path\n\n    @staticmethod\n    def run(model, selected_path, custom_steps, eta):\n        example = get_cond(selected_path)\n\n        n_runs = 1\n        guider = None\n        ckwargs = None\n        ddim_use_x0_pred = False\n        temperature = 1.\n        eta = eta\n        custom_shape = None\n\n        height, width = example[\"image\"].shape[1:3]\n        split_input = height >= 128 and width >= 128\n\n        if split_input:\n            ks = 128\n            stride = 64\n            vqf = 4  #\n            model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n                                        \"vqf\": vqf,\n                                        \"patch_distributed_vq\": True,\n                                        \"tie_braker\": False,\n                                        \"clip_max_weight\": 0.5,\n                                        \"clip_min_weight\": 0.01,\n                                        \"clip_max_tie_weight\": 0.5,\n                                        \"clip_min_tie_weight\": 0.01}\n        else:\n            if hasattr(model, \"split_input_params\"):\n                delattr(model, \"split_input_params\")\n\n        x_t = None\n        logs = None\n        for _ in range(n_runs):\n            if custom_shape is not None:\n                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])\n\n            logs = make_convolutional_sample(example, model,\n                                             custom_steps=custom_steps,\n                                             eta=eta, quantize_x0=False,\n                                             custom_shape=custom_shape,\n                                             temperature=temperature, noise_dropout=0.,\n                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,\n          
                                   ddim_use_x0_pred=ddim_use_x0_pred\n                                             )\n        return logs\n\n    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):\n        model = self.load_model_from_config(half_attention)\n\n        # Run settings\n        diffusion_steps = int(steps)\n        eta = 1.0\n\n\n        gc.collect()\n        devices.torch_gc()\n\n        im_og = image\n        width_og, height_og = im_og.size\n        # If we can adjust the max upscale size, then the 4 below should be our variable\n        down_sample_rate = target_scale / 4\n        wd = width_og * down_sample_rate\n        hd = height_og * down_sample_rate\n        width_downsampled_pre = int(np.ceil(wd))\n        height_downsampled_pre = int(np.ceil(hd))\n\n        if down_sample_rate != 1:\n            print(\n                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n        else:\n            print(f\"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)\")\n\n        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts\n        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size\n        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))\n\n        logs = self.run(model[\"model\"], im_padded, diffusion_steps, eta)\n\n        sample = logs[\"sample\"]\n        sample = sample.detach().cpu()\n        sample = torch.clamp(sample, -1., 1.)\n        sample = (sample + 1.) / 2. 
* 255\n        sample = sample.numpy().astype(np.uint8)\n        sample = np.transpose(sample, (0, 2, 3, 1))\n        a = Image.fromarray(sample[0])\n\n        # remove padding\n        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))\n\n        del model\n        gc.collect()\n        devices.torch_gc()\n\n        return a\n\n\ndef get_cond(selected_path):\n    example = {}\n    up_f = 4\n    c = selected_path.convert('RGB')\n    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],\n                                                    antialias=True)\n    c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n    c = rearrange(c, '1 c h w -> 1 h w c')\n    c = 2. * c - 1.\n\n    c = c.to(shared.device)\n    example[\"LR_image\"] = c\n    example[\"image\"] = c_up\n\n    return example\n\n\n@torch.no_grad()\ndef convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,\n                    corrector_kwargs=None, x_t=None\n                    ):\n    ddim = DDIMSampler(model)\n    bs = shape[0]\n    shape = shape[1:]\n    print(f\"Sampling with eta = {eta}; steps: {steps}\")\n    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n                                         mask=mask, x0=x0, temperature=temperature, verbose=False,\n                                         score_corrector=score_corrector,\n                                         corrector_kwargs=corrector_kwargs, x_t=x_t)\n\n    return samples, intermediates\n\n\n@torch.no_grad()\ndef make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., 
noise_dropout=0., corrector=None,\n                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):\n    log = {}\n\n    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n                                        return_first_stage_outputs=True,\n                                        force_c_encode=not (hasattr(model, 'split_input_params')\n                                                            and model.cond_stage_key == 'coordinates_bbox'),\n                                        return_original_cond=True)\n\n    if custom_shape is not None:\n        z = torch.randn(custom_shape)\n        print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n\n    z0 = None\n\n    log[\"input\"] = x\n    log[\"reconstruction\"] = xrec\n\n    if ismap(xc):\n        log[\"original_conditioning\"] = model.to_rgb(xc)\n        if hasattr(model, 'cond_stage_key'):\n            log[model.cond_stage_key] = model.to_rgb(xc)\n\n    else:\n        log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n        if model.cond_stage_model:\n            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n            if model.cond_stage_key == 'class_label':\n                log[model.cond_stage_key] = xc[model.cond_stage_key]\n\n    with model.ema_scope(\"Plotting\"):\n        t0 = time.time()\n\n        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n                                                eta=eta,\n                                                quantize_x0=quantize_x0, mask=None, x0=z0,\n                                                temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n                                                x_t=x_T)\n        t1 = time.time()\n\n        if ddim_use_x0_pred:\n            sample = intermediates['pred_x0'][-1]\n\n    x_sample = 
model.decode_first_stage(sample)\n\n    try:\n        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n        log[\"sample_noquant\"] = x_sample_noquant\n        log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n    except Exception:\n        pass\n\n    log[\"sample\"] = x_sample\n    log[\"time\"] = t1 - t0\n\n    return log\n"
  },
  {
    "path": "extensions-builtin/LDSR/preload.py",
    "content": "import os\r\nfrom modules import paths\r\n\r\n\r\ndef preload(parser):\r\n    parser.add_argument(\"--ldsr-models-path\", type=str, help=\"Path to directory with LDSR model file(s).\", default=os.path.join(paths.models_path, 'LDSR'))\r\n"
  },
  {
    "path": "extensions-builtin/LDSR/scripts/ldsr_model.py",
    "content": "import os\n\nfrom modules.modelloader import load_file_from_url\nfrom modules.upscaler import Upscaler, UpscalerData\nfrom ldsr_model_arch import LDSR\nfrom modules import shared, script_callbacks, errors\nimport sd_hijack_autoencoder  # noqa: F401\nimport sd_hijack_ddpm_v1  # noqa: F401\n\n\nclass UpscalerLDSR(Upscaler):\n    def __init__(self, user_path):\n        self.name = \"LDSR\"\n        self.user_path = user_path\n        self.model_url = \"https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1\"\n        self.yaml_url = \"https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1\"\n        super().__init__()\n        scaler_data = UpscalerData(\"LDSR\", None, self)\n        self.scalers = [scaler_data]\n\n    def load_model(self, path: str):\n        # Remove incorrect project.yaml file if too big\n        yaml_path = os.path.join(self.model_path, \"project.yaml\")\n        old_model_path = os.path.join(self.model_path, \"model.pth\")\n        new_model_path = os.path.join(self.model_path, \"model.ckpt\")\n\n        local_model_paths = self.find_models(ext_filter=[\".ckpt\", \".safetensors\"])\n        local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith(\"model.ckpt\")]), None)\n        local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith(\"model.safetensors\")]), None)\n        local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith(\"project.yaml\")]), None)\n\n        if os.path.exists(yaml_path):\n            statinfo = os.stat(yaml_path)\n            if statinfo.st_size >= 10485760:\n                print(\"Removing invalid LDSR YAML file.\")\n                os.remove(yaml_path)\n\n        if os.path.exists(old_model_path):\n            print(\"Renaming model from model.pth to model.ckpt\")\n            os.rename(old_model_path, new_model_path)\n\n        if 
local_safetensors_path is not None and os.path.exists(local_safetensors_path):\n            model = local_safetensors_path\n        else:\n            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=\"model.ckpt\")\n\n        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name=\"project.yaml\")\n\n        return LDSR(model, yaml)\n\n    def do_upscale(self, img, path):\n        try:\n            ldsr = self.load_model(path)\n        except Exception:\n            errors.report(f\"Failed loading LDSR model {path}\", exc_info=True)\n            return img\n        ddim_steps = shared.opts.ldsr_steps\n        return ldsr.super_resolution(img, ddim_steps, self.scale)\n\n\ndef on_ui_settings():\n    import gradio as gr\n\n    shared.opts.add_option(\"ldsr_steps\", shared.OptionInfo(100, \"LDSR processing steps. Lower = faster\", gr.Slider, {\"minimum\": 1, \"maximum\": 200, \"step\": 1}, section=('upscaling', \"Upscaling\")))\n    shared.opts.add_option(\"ldsr_cached\", shared.OptionInfo(False, \"Cache LDSR model in memory\", gr.Checkbox, {\"interactive\": True}, section=('upscaling', \"Upscaling\")))\n\n\nscript_callbacks.on_ui_settings(on_ui_settings)\n"
  },
  {
    "path": "extensions-builtin/LDSR/sd_hijack_autoencoder.py",
    "content": "# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo\n# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo\n# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder\nimport numpy as np\nimport torch\nimport pytorch_lightning as pl\nimport torch.nn.functional as F\nfrom contextlib import contextmanager\n\nfrom torch.optim.lr_scheduler import LambdaLR\n\nfrom ldm.modules.ema import LitEma\nfrom vqvae_quantize import VectorQuantizer2 as VectorQuantizer\nfrom ldm.modules.diffusionmodules.model import Encoder, Decoder\nfrom ldm.util import instantiate_from_config\n\nimport ldm.models.autoencoder\nfrom packaging import version\n\nclass VQModel(pl.LightningModule):\n    def __init__(self,\n                 ddconfig,\n                 lossconfig,\n                 n_embed,\n                 embed_dim,\n                 ckpt_path=None,\n                 ignore_keys=None,\n                 image_key=\"image\",\n                 colorize_nlabels=None,\n                 monitor=None,\n                 batch_resize_range=None,\n                 scheduler_config=None,\n                 lr_g_factor=1.0,\n                 remap=None,\n                 sane_index_shape=False, # tell vector quantizer to return indices as bhw\n                 use_ema=False\n                 ):\n        super().__init__()\n        self.embed_dim = embed_dim\n        self.n_embed = n_embed\n        self.image_key = image_key\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.loss = instantiate_from_config(lossconfig)\n        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,\n                                        remap=remap,\n                                        sane_index_shape=sane_index_shape)\n      
  self.quant_conv = torch.nn.Conv2d(ddconfig[\"z_channels\"], embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig[\"z_channels\"], 1)\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels)==int\n            self.register_buffer(\"colorize\", torch.randn(3, colorize_nlabels, 1, 1))\n        if monitor is not None:\n            self.monitor = monitor\n        self.batch_resize_range = batch_resize_range\n        if self.batch_resize_range is not None:\n            print(f\"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.\")\n\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])\n        self.scheduler_config = scheduler_config\n        self.lr_g_factor = lr_g_factor\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.parameters())\n            self.model_ema.copy_to(self)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def init_from_ckpt(self, path, ignore_keys=None):\n        sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys or []:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False)\n        
print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if missing:\n            print(f\"Missing Keys: {missing}\")\n        if unexpected:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self)\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        quant, emb_loss, info = self.quantize(h)\n        return quant, emb_loss, info\n\n    def encode_to_prequant(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, quant):\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\n    def decode_code(self, code_b):\n        quant_b = self.quantize.embed_code(code_b)\n        dec = self.decode(quant_b)\n        return dec\n\n    def forward(self, input, return_pred_indices=False):\n        quant, diff, (_,_,ind) = self.encode(input)\n        dec = self.decode(quant)\n        if return_pred_indices:\n            return dec, diff, ind\n        return dec, diff\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()\n        if self.batch_resize_range is not None:\n            lower_size = self.batch_resize_range[0]\n            upper_size = self.batch_resize_range[1]\n            if self.global_step <= 4:\n                # do the first few batches with max size to avoid later oom\n                new_resize = upper_size\n            else:\n                new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))\n            if new_resize != x.shape[2]:\n                x = F.interpolate(x, size=new_resize, mode=\"bicubic\")\n            x = x.detach()\n        return x\n\n    def training_step(self, 
batch, batch_idx, optimizer_idx):\n        # https://github.com/pytorch/pytorch/issues/37142\n        # try not to fool the heuristics\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n\n        if optimizer_idx == 0:\n            # autoencode\n            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"train\",\n                                            predicted_indices=ind)\n\n            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return aeloss\n\n        if optimizer_idx == 1:\n            # discriminator\n            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,\n                                            last_layer=self.get_last_layer(), split=\"train\")\n            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)\n            return discloss\n\n    def validation_step(self, batch, batch_idx):\n        log_dict = self._validation_step(batch, batch_idx)\n        with self.ema_scope():\n            self._validation_step(batch, batch_idx, suffix=\"_ema\")\n        return log_dict\n\n    def _validation_step(self, batch, batch_idx, suffix=\"\"):\n        x = self.get_input(batch, self.image_key)\n        xrec, qloss, ind = self(x, return_pred_indices=True)\n        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,\n                                        self.global_step,\n                                        last_layer=self.get_last_layer(),\n                                        split=\"val\"+suffix,\n                                        predicted_indices=ind\n                                        )\n\n        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,\n                                            self.global_step,\n          
                                  last_layer=self.get_last_layer(),\n                                            split=\"val\"+suffix,\n                                            predicted_indices=ind\n                                            )\n        rec_loss = log_dict_ae[f\"val{suffix}/rec_loss\"]\n        self.log(f\"val{suffix}/rec_loss\", rec_loss,\n                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)\n        self.log(f\"val{suffix}/aeloss\", aeloss,\n                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)\n        if version.parse(pl.__version__) >= version.parse('1.4.0'):\n            del log_dict_ae[f\"val{suffix}/rec_loss\"]\n        self.log_dict(log_dict_ae)\n        self.log_dict(log_dict_disc)\n        return self.log_dict\n\n    def configure_optimizers(self):\n        lr_d = self.learning_rate\n        lr_g = self.lr_g_factor*self.learning_rate\n        print(\"lr_d\", lr_d)\n        print(\"lr_g\", lr_g)\n        opt_ae = torch.optim.Adam(list(self.encoder.parameters())+\n                                  list(self.decoder.parameters())+\n                                  list(self.quantize.parameters())+\n                                  list(self.quant_conv.parameters())+\n                                  list(self.post_quant_conv.parameters()),\n                                  lr=lr_g, betas=(0.5, 0.9))\n        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),\n                                    lr=lr_d, betas=(0.5, 0.9))\n\n        if self.scheduler_config is not None:\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                },\n                {\n   
                 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                },\n            ]\n            return [opt_ae, opt_disc], scheduler\n        return [opt_ae, opt_disc], []\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):\n        log = {}\n        x = self.get_input(batch, self.image_key)\n        x = x.to(self.device)\n        if only_inputs:\n            log[\"inputs\"] = x\n            return log\n        xrec, _ = self(x)\n        if x.shape[1] > 3:\n            # colorize with random projection\n            assert xrec.shape[1] > 3\n            x = self.to_rgb(x)\n            xrec = self.to_rgb(xrec)\n        log[\"inputs\"] = x\n        log[\"reconstructions\"] = xrec\n        if plot_ema:\n            with self.ema_scope():\n                xrec_ema, _ = self(x)\n                if x.shape[1] > 3:\n                    xrec_ema = self.to_rgb(xrec_ema)\n                log[\"reconstructions_ema\"] = xrec_ema\n        return log\n\n    def to_rgb(self, x):\n        assert self.image_key == \"segmentation\"\n        if not hasattr(self, \"colorize\"):\n            self.register_buffer(\"colorize\", torch.randn(3, x.shape[1], 1, 1).to(x))\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.\n        return x\n\n\nclass VQModelInterface(VQModel):\n    def __init__(self, embed_dim, *args, **kwargs):\n        super().__init__(*args, embed_dim=embed_dim, **kwargs)\n        self.embed_dim = embed_dim\n\n    def encode(self, x):\n        h = self.encoder(x)\n        h = self.quant_conv(h)\n        return h\n\n    def decode(self, h, force_not_quantize=False):\n        # also go through quantization layer\n        if not force_not_quantize:\n            quant, emb_loss, info = self.quantize(h)\n        else:\n  
          quant = h\n        quant = self.post_quant_conv(quant)\n        dec = self.decoder(quant)\n        return dec\n\nldm.models.autoencoder.VQModel = VQModel\nldm.models.autoencoder.VQModelInterface = VQModelInterface\n"
  },
  {
    "path": "extensions-builtin/LDSR/sd_hijack_ddpm_v1.py",
    "content": "# This script is copied from the compvis/stable-diffusion repo (aka the SD V1 repo)\n# Original filename: ldm/models/diffusion/ddpm.py\n# The purpose to reinstate the old DDPM logic which works with VQ, whereas the V2 one doesn't\n# Some models such as LDSR require VQ to work correctly\n# The classes are suffixed with \"V1\" and added back to the \"ldm.models.diffusion.ddpm\" module\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom einops import rearrange, repeat\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom tqdm import tqdm\nfrom torchvision.utils import make_grid\nfrom pytorch_lightning.utilities.distributed import rank_zero_only\n\nfrom ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config\nfrom ldm.modules.ema import LitEma\nfrom ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution\nfrom ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL\nfrom ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like\nfrom ldm.models.diffusion.ddim import DDIMSampler\n\nimport ldm.models.diffusion.ddpm\n\n__conditioning_keys__ = {'concat': 'c_concat',\n                         'crossattn': 'c_crossattn',\n                         'adm': 'y'}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef uniform_on_device(r1, r2, shape, device):\n    return (r1 - r2) * torch.rand(*shape, device=device) + r2\n\n\nclass DDPMV1(pl.LightningModule):\n    # classic DDPM with Gaussian diffusion, in image space\n    def __init__(self,\n                 unet_config,\n                 timesteps=1000,\n                 beta_schedule=\"linear\",\n                 
loss_type=\"l2\",\n                 ckpt_path=None,\n                 ignore_keys=None,\n                 load_only_unet=False,\n                 monitor=\"val/loss\",\n                 use_ema=True,\n                 first_stage_key=\"image\",\n                 image_size=256,\n                 channels=3,\n                 log_every_t=100,\n                 clip_denoised=True,\n                 linear_start=1e-4,\n                 linear_end=2e-2,\n                 cosine_s=8e-3,\n                 given_betas=None,\n                 original_elbo_weight=0.,\n                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta\n                 l_simple_weight=1.,\n                 conditioning_key=None,\n                 parameterization=\"eps\",  # all assuming fixed variance schedules\n                 scheduler_config=None,\n                 use_positional_encodings=False,\n                 learn_logvar=False,\n                 logvar_init=0.,\n                 ):\n        super().__init__()\n        assert parameterization in [\"eps\", \"x0\"], 'currently only supporting \"eps\" and \"x0\"'\n        self.parameterization = parameterization\n        print(f\"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode\")\n        self.cond_stage_model = None\n        self.clip_denoised = clip_denoised\n        self.log_every_t = log_every_t\n        self.first_stage_key = first_stage_key\n        self.image_size = image_size  # try conv?\n        self.channels = channels\n        self.use_positional_encodings = use_positional_encodings\n        self.model = DiffusionWrapperV1(unet_config, conditioning_key)\n        count_params(self.model, verbose=True)\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self.model)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.use_scheduler = scheduler_config is not 
None\n        if self.use_scheduler:\n            self.scheduler_config = scheduler_config\n\n        self.v_posterior = v_posterior\n        self.original_elbo_weight = original_elbo_weight\n        self.l_simple_weight = l_simple_weight\n\n        if monitor is not None:\n            self.monitor = monitor\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)\n\n        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,\n                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)\n\n        self.loss_type = loss_type\n\n        self.learn_logvar = learn_logvar\n        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))\n        if self.learn_logvar:\n            self.logvar = nn.Parameter(self.logvar, requires_grad=True)\n\n\n    def register_schedule(self, given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        if exists(given_betas):\n            betas = given_betas\n        else:\n            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,\n                                       cosine_s=cosine_s)\n        alphas = 1. 
- betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer('betas', to_torch(betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (\n                    1. - alphas_cumprod) + self.v_posterior * betas\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer('posterior_variance', to_torch(posterior_variance))\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))\n        self.register_buffer('posterior_mean_coef1', to_torch(\n            betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod)))\n        self.register_buffer('posterior_mean_coef2', to_torch(\n            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))\n\n        if self.parameterization == \"eps\":\n            lvlb_weights = self.betas ** 2 / (\n                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))\n        elif self.parameterization == \"x0\":\n            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))\n        else:\n            raise NotImplementedError(\"mu not supported\")\n        # TODO how to choose this term\n        lvlb_weights[0] = lvlb_weights[1]\n        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)\n        assert not torch.isnan(self.lvlb_weights).all()\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):\n        sd = torch.load(path, map_location=\"cpu\")\n        if \"state_dict\" in list(sd.keys()):\n            sd = sd[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys or []:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(\n            sd, strict=False)\n        
print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if missing:\n            print(f\"Missing Keys: {missing}\")\n        if unexpected:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)\n        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, clip_denoised: bool):\n        model_out = self.model(x, t)\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif 
self.parameterization == \"x0\":\n            x_recon = model_out\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape, return_intermediates=False):\n        device = self.betas.device\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n        intermediates = [img]\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),\n                                clip_denoised=self.clip_denoised)\n            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:\n                intermediates.append(img)\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, batch_size=16, return_intermediates=False):\n        image_size = self.image_size\n        channels = self.channels\n        return self.p_sample_loop((batch_size, channels, image_size, image_size),\n                                  return_intermediates=return_intermediates)\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return 
(extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)\n\n    def get_loss(self, pred, target, mean=True):\n        if self.loss_type == 'l1':\n            loss = (target - pred).abs()\n            if mean:\n                loss = loss.mean()\n        elif self.loss_type == 'l2':\n            if mean:\n                loss = torch.nn.functional.mse_loss(target, pred)\n            else:\n                loss = torch.nn.functional.mse_loss(target, pred, reduction='none')\n        else:\n            raise NotImplementedError(\"unknown loss type '{loss_type}'\")\n\n        return loss\n\n    def p_losses(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_out = self.model(x_noisy, t)\n\n        loss_dict = {}\n        if self.parameterization == \"eps\":\n            target = noise\n        elif self.parameterization == \"x0\":\n            target = x_start\n        else:\n            raise NotImplementedError(f\"Parameterization {self.parameterization} not yet supported\")\n\n        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])\n\n        log_prefix = 'train' if self.training else 'val'\n\n        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})\n        loss_simple = loss.mean() * self.l_simple_weight\n\n        loss_vlb = (self.lvlb_weights[t] * loss).mean()\n        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})\n\n        loss = loss_simple + self.original_elbo_weight * loss_vlb\n\n        loss_dict.update({f'{log_prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def forward(self, x, *args, **kwargs):\n        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size\n        # assert h == img_size and w == img_size, f'height and width of image 
must be {img_size}'\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        return self.p_losses(x, t, *args, **kwargs)\n\n    def get_input(self, batch, k):\n        x = batch[k]\n        if len(x.shape) == 3:\n            x = x[..., None]\n        x = rearrange(x, 'b h w c -> b c h w')\n        x = x.to(memory_format=torch.contiguous_format).float()\n        return x\n\n    def shared_step(self, batch):\n        x = self.get_input(batch, self.first_stage_key)\n        loss, loss_dict = self(x)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        loss, loss_dict = self.shared_step(batch)\n\n        self.log_dict(loss_dict, prog_bar=True,\n                      logger=True, on_step=True, on_epoch=True)\n\n        self.log(\"global_step\", self.global_step,\n                 prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        if self.use_scheduler:\n            lr = self.optimizers().param_groups[0]['lr']\n            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        return loss\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        _, loss_dict_no_ema = self.shared_step(batch)\n        with self.ema_scope():\n            _, loss_dict_ema = self.shared_step(batch)\n            loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}\n        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    def _get_rows_from_list(self, samples):\n        n_imgs_per_row = len(samples)\n        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        
denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):\n        log = {}\n        x = self.get_input(batch, self.first_stage_key)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        x = x.to(self.device)[:N]\n        log[\"inputs\"] = x\n\n        # get diffusion row\n        diffusion_row = []\n        x_start = x[:n_row]\n\n        for t in range(self.num_timesteps):\n            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                t = t.to(self.device).long()\n                noise = torch.randn_like(x_start)\n                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n                diffusion_row.append(x_noisy)\n\n        log[\"diffusion_row\"] = self._get_rows_from_list(diffusion_row)\n\n        if sample:\n            # get denoise row\n            with self.ema_scope(\"Plotting\"):\n                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)\n\n            log[\"samples\"] = samples\n            log[\"denoise_row\"] = self._get_rows_from_list(denoise_row)\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        if self.learn_logvar:\n            params = params + [self.logvar]\n        opt = torch.optim.AdamW(params, lr=lr)\n        return opt\n\n\nclass LatentDiffusionV1(DDPMV1):\n    \"\"\"main class\"\"\"\n    def __init__(self,\n                 first_stage_config,\n                 cond_stage_config,\n                 
num_timesteps_cond=None,\n                 cond_stage_key=\"image\",\n                 cond_stage_trainable=False,\n                 concat_mode=True,\n                 cond_stage_forward=None,\n                 conditioning_key=None,\n                 scale_factor=1.0,\n                 scale_by_std=False,\n                 *args, **kwargs):\n        self.num_timesteps_cond = default(num_timesteps_cond, 1)\n        self.scale_by_std = scale_by_std\n        assert self.num_timesteps_cond <= kwargs['timesteps']\n        # for backwards compatibility after implementation of DiffusionWrapper\n        if conditioning_key is None:\n            conditioning_key = 'concat' if concat_mode else 'crossattn'\n        if cond_stage_config == '__is_unconditional__':\n            conditioning_key = None\n        ckpt_path = kwargs.pop(\"ckpt_path\", None)\n        ignore_keys = kwargs.pop(\"ignore_keys\", [])\n        super().__init__(*args, conditioning_key=conditioning_key, **kwargs)\n        self.concat_mode = concat_mode\n        self.cond_stage_trainable = cond_stage_trainable\n        self.cond_stage_key = cond_stage_key\n        try:\n            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1\n        except Exception:\n            self.num_downs = 0\n        if not scale_by_std:\n            self.scale_factor = scale_factor\n        else:\n            self.register_buffer('scale_factor', torch.tensor(scale_factor))\n        self.instantiate_first_stage(first_stage_config)\n        self.instantiate_cond_stage(cond_stage_config)\n        self.cond_stage_forward = cond_stage_forward\n        self.clip_denoised = False\n        self.bbox_tokenizer = None\n\n        self.restarted_from_ckpt = False\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys)\n            self.restarted_from_ckpt = True\n\n    def make_cond_schedule(self, ):\n        self.cond_ids = torch.full(size=(self.num_timesteps,), 
fill_value=self.num_timesteps - 1, dtype=torch.long)\n        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()\n        self.cond_ids[:self.num_timesteps_cond] = ids\n\n    @rank_zero_only\n    @torch.no_grad()\n    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):\n        # only for very first batch\n        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:\n            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'\n            # set rescale weight to 1./std of encodings\n            print(\"### USING STD-RESCALING ###\")\n            x = super().get_input(batch, self.first_stage_key)\n            x = x.to(self.device)\n            encoder_posterior = self.encode_first_stage(x)\n            z = self.get_first_stage_encoding(encoder_posterior).detach()\n            del self.scale_factor\n            self.register_buffer('scale_factor', 1. 
/ z.flatten().std())\n            print(f\"setting self.scale_factor to {self.scale_factor}\")\n            print(\"### USING STD-RESCALING ###\")\n\n    def register_schedule(self,\n                          given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)\n\n        self.shorten_cond_schedule = self.num_timesteps_cond > 1\n        if self.shorten_cond_schedule:\n            self.make_cond_schedule()\n\n    def instantiate_first_stage(self, config):\n        model = instantiate_from_config(config)\n        self.first_stage_model = model.eval()\n        self.first_stage_model.train = disabled_train\n        for param in self.first_stage_model.parameters():\n            param.requires_grad = False\n\n    def instantiate_cond_stage(self, config):\n        if not self.cond_stage_trainable:\n            if config == \"__is_first_stage__\":\n                print(\"Using first stage also as cond stage.\")\n                self.cond_stage_model = self.first_stage_model\n            elif config == \"__is_unconditional__\":\n                print(f\"Training {self.__class__.__name__} as an unconditional model.\")\n                self.cond_stage_model = None\n                # self.be_unconditional = True\n            else:\n                model = instantiate_from_config(config)\n                self.cond_stage_model = model.eval()\n                self.cond_stage_model.train = disabled_train\n                for param in self.cond_stage_model.parameters():\n                    param.requires_grad = False\n        else:\n            assert config != '__is_first_stage__'\n            assert config != '__is_unconditional__'\n            model = instantiate_from_config(config)\n            self.cond_stage_model = model\n\n    def _get_denoise_row_from_list(self, samples, 
desc='', force_no_decoder_quantization=False):\n        denoise_row = []\n        for zd in tqdm(samples, desc=desc):\n            denoise_row.append(self.decode_first_stage(zd.to(self.device),\n                                                            force_not_quantize=force_no_decoder_quantization))\n        n_imgs_per_row = len(denoise_row)\n        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W\n        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    def get_first_stage_encoding(self, encoder_posterior):\n        if isinstance(encoder_posterior, DiagonalGaussianDistribution):\n            z = encoder_posterior.sample()\n        elif isinstance(encoder_posterior, torch.Tensor):\n            z = encoder_posterior\n        else:\n            raise NotImplementedError(f\"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented\")\n        return self.scale_factor * z\n\n    def get_learned_conditioning(self, c):\n        if self.cond_stage_forward is None:\n            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):\n                c = self.cond_stage_model.encode(c)\n                if isinstance(c, DiagonalGaussianDistribution):\n                    c = c.mode()\n            else:\n                c = self.cond_stage_model(c)\n        else:\n            assert hasattr(self.cond_stage_model, self.cond_stage_forward)\n            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)\n        return c\n\n    def meshgrid(self, h, w):\n        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)\n        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)\n\n        arr = torch.cat([y, x], dim=-1)\n        return arr\n\n    def delta_border(self, h, w):\n        \"\"\"\n        
:param h: height\n        :param w: width\n        :return: normalized distance to image border,\n         with min distance = 0 at border and max dist = 0.5 at image center\n        \"\"\"\n        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)\n        arr = self.meshgrid(h, w) / lower_right_corner\n        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]\n        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]\n        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]\n        return edge_dist\n\n    def get_weighting(self, h, w, Ly, Lx, device):\n        weighting = self.delta_border(h, w)\n        weighting = torch.clip(weighting, self.split_input_params[\"clip_min_weight\"],\n                               self.split_input_params[\"clip_max_weight\"], )\n        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)\n\n        if self.split_input_params[\"tie_braker\"]:\n            L_weighting = self.delta_border(Ly, Lx)\n            L_weighting = torch.clip(L_weighting,\n                                     self.split_input_params[\"clip_min_tie_weight\"],\n                                     self.split_input_params[\"clip_max_tie_weight\"])\n\n            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)\n            weighting = weighting * L_weighting\n        return weighting\n\n    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code\n        \"\"\"\n        :param x: img of size (bs, c, h, w)\n        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])\n        \"\"\"\n        bs, nc, h, w = x.shape\n\n        # number of crops in image\n        Ly = (h - kernel_size[0]) // stride[0] + 1\n        Lx = (w - kernel_size[1]) // stride[1] + 1\n\n        if uf == 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            
unfold = torch.nn.Unfold(**fold_params)\n\n            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)\n\n            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))\n\n        elif uf > 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),\n                                dilation=1, padding=0,\n                                stride=(stride[0] * uf, stride[1] * uf))\n            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))\n\n        elif df > 1 and uf == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),\n                                dilation=1, padding=0,\n                                stride=(stride[0] // df, stride[1] // df))\n            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap\n            weighting = weighting.view((1, 1, 
kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))\n\n        else:\n            raise NotImplementedError\n\n        return fold, unfold, normalization, weighting\n\n    @torch.no_grad()\n    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,\n                  cond_key=None, return_original_cond=False, bs=None):\n        x = super().get_input(batch, k)\n        if bs is not None:\n            x = x[:bs]\n        x = x.to(self.device)\n        encoder_posterior = self.encode_first_stage(x)\n        z = self.get_first_stage_encoding(encoder_posterior).detach()\n\n        if self.model.conditioning_key is not None:\n            if cond_key is None:\n                cond_key = self.cond_stage_key\n            if cond_key != self.first_stage_key:\n                if cond_key in ['caption', 'coordinates_bbox']:\n                    xc = batch[cond_key]\n                elif cond_key == 'class_label':\n                    xc = batch\n                else:\n                    xc = super().get_input(batch, cond_key).to(self.device)\n            else:\n                xc = x\n            if not self.cond_stage_trainable or force_c_encode:\n                if isinstance(xc, dict) or isinstance(xc, list):\n                    # import pudb; pudb.set_trace()\n                    c = self.get_learned_conditioning(xc)\n                else:\n                    c = self.get_learned_conditioning(xc.to(self.device))\n            else:\n                c = xc\n            if bs is not None:\n                c = c[:bs]\n\n            if self.use_positional_encodings:\n                pos_x, pos_y = self.compute_latent_shifts(batch)\n                ckey = __conditioning_keys__[self.model.conditioning_key]\n                c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}\n\n        else:\n            c = None\n            xc = None\n            if self.use_positional_encodings:\n                pos_x, pos_y = self.compute_latent_shifts(batch)\n    
            c = {'pos_x': pos_x, 'pos_y': pos_y}\n        out = [z, c]\n        if return_first_stage_outputs:\n            xrec = self.decode_first_stage(z)\n            out.extend([x, xrec])\n        if return_original_cond:\n            out.append(xc)\n        return out\n\n    @torch.no_grad()\n    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1. / self.scale_factor * z\n\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n                uf = self.split_input_params[\"vqf\"]\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. 
apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],\n                                                                 force_not_quantize=predict_cids or force_not_quantize)\n                                   for i in range(z.shape[-1])]\n                else:\n\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n                                   for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n            else:\n                return self.first_stage_model.decode(z)\n\n    # same as above but without decorator\n    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1. 
/ self.scale_factor * z\n\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n                uf = self.split_input_params[\"vqf\"]\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],\n                                                                 force_not_quantize=predict_cids or force_not_quantize)\n                                   for i in range(z.shape[-1])]\n                else:\n\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n                                   for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. 
reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n            else:\n                return self.first_stage_model.decode(z)\n\n    @torch.no_grad()\n    def encode_first_stage(self, x):\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. 
(64, 64)\n                df = self.split_input_params[\"vqf\"]\n                self.split_input_params['original_image_size'] = x.shape[-2:]\n                bs, nc, h, w = x.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)\n                z = unfold(x)  # (bn, nc * prod(**ks), L)\n                # Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])\n                               for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)\n                o = o * weighting\n\n                # Reverse reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization\n                return decoded\n\n            else:\n                return self.first_stage_model.encode(x)\n        else:\n            return self.first_stage_model.encode(x)\n\n    def shared_step(self, batch, **kwargs):\n        x, c = self.get_input(batch, self.first_stage_key)\n        loss = self(x, c)\n        return loss\n\n    def forward(self, x, c, *args, **kwargs):\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        if self.model.conditioning_key is not None:\n            assert c is not None\n            if self.cond_stage_trainable:\n                c = self.get_learned_conditioning(c)\n            if 
self.shorten_cond_schedule:  # TODO: drop this option\n                tc = self.cond_ids[t].to(self.device)\n                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))\n        return self.p_losses(x, c, t, *args, **kwargs)\n\n    def apply_model(self, x_noisy, t, cond, return_ids=False):\n\n        if isinstance(cond, dict):\n            # hybrid case, cond is expected to be a dict\n            pass\n        else:\n            if not isinstance(cond, list):\n                cond = [cond]\n            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'\n            cond = {key: cond}\n\n        if hasattr(self, \"split_input_params\"):\n            assert len(cond) == 1  # todo can only deal with one conditioning atm\n            assert not return_ids\n            ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n            stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n\n            h, w = x_noisy.shape[-2:]\n\n            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)\n\n            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)\n            # Reshape to img shape\n            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]\n\n            if self.cond_stage_key in [\"image\", \"LR_image\", \"segmentation\",\n                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness\n                c_key = next(iter(cond.keys()))  # get key\n                c = next(iter(cond.values()))  # get value\n                assert (len(c) == 1)  # todo extend to list with more than one elem\n                c = c[0]  # get element\n\n                c = unfold(c)\n                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                cond_list = [{c_key: [c[:, :, :, 
:, i]]} for i in range(c.shape[-1])]\n\n            elif self.cond_stage_key == 'coordinates_bbox':\n                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'\n\n                # assuming padding of unfold is always 0 and its dilation is always 1\n                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)\n                full_img_h, full_img_w = self.split_input_params['original_image_size']\n                # as we are operating on latents, we need the factor from the original image size to the\n                # spatial latent size to properly rescale the crops for regenerating the bbox annotations\n                num_downs = self.first_stage_model.encoder.num_resolutions - 1\n                rescale_latent = 2 ** (num_downs)\n\n                # get top left positions of patches as conforming for the bbbox tokenizer, therefore we\n                # need to rescale the tl patch coordinates to be in between (0,1)\n                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,\n                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)\n                                        for patch_nr in range(z.shape[-1])]\n\n                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)\n                patch_limits = [(x_tl, y_tl,\n                                 rescale_latent * ks[0] / full_img_w,\n                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]\n                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]\n\n                # tokenize crop coordinates for the bounding boxes of the respective patches\n                patch_limits_tknzd = 
[torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)\n                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)\n                print(patch_limits_tknzd[0].shape)\n                # cut tknzd crop position from conditioning\n                assert isinstance(cond, dict), 'cond must be dict to be fed into model'\n                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)\n                print(cut_cond.shape)\n\n                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])\n                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')\n                print(adapted_cond.shape)\n                adapted_cond = self.get_learned_conditioning(adapted_cond)\n                print(adapted_cond.shape)\n                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])\n                print(adapted_cond.shape)\n\n                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]\n\n            else:\n                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient\n\n            # apply model by loop over crops\n            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]\n            assert not isinstance(output_list[0],\n                                  tuple)  # todo cant deal with multiple model outputs check this never happens\n\n            o = torch.stack(output_list, axis=-1)\n            o = o * weighting\n            # Reverse reshape to img shape\n            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n            # stitch crops together\n            x_recon = fold(o) / normalization\n\n        else:\n            x_recon = self.model(x_noisy, t, **cond)\n\n        if isinstance(x_recon, tuple) and not return_ids:\n            return x_recon[0]\n        else:\n            return 
x_recon\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \\\n               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n        This term can't be optimized, as it only depends on the encoder.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def p_losses(self, x_start, cond, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_output = self.apply_model(x_noisy, t, cond)\n\n        loss_dict = {}\n        prefix = 'train' if self.training else 'val'\n\n        if self.parameterization == \"x0\":\n            target = x_start\n        elif self.parameterization == \"eps\":\n            target = noise\n        else:\n            raise NotImplementedError()\n\n        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])\n        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})\n\n        logvar_t = self.logvar[t].to(self.device)\n        loss = loss_simple / torch.exp(logvar_t) + logvar_t\n        # loss = loss_simple / torch.exp(self.logvar) + self.logvar\n        if self.learn_logvar:\n            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})\n            
loss_dict.update({'logvar': self.logvar.data.mean()})\n\n        loss = self.l_simple_weight * loss.mean()\n\n        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))\n        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()\n        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})\n        loss += (self.original_elbo_weight * loss_vlb)\n        loss_dict.update({f'{prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,\n                        return_x0=False, score_corrector=None, corrector_kwargs=None):\n        t_in = t\n        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)\n\n        if score_corrector is not None:\n            assert self.parameterization == \"eps\"\n            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)\n\n        if return_codebook_ids:\n            model_out, logits = model_out\n\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == \"x0\":\n            x_recon = model_out\n        else:\n            raise NotImplementedError()\n\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n        if quantize_denoised:\n            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        if return_codebook_ids:\n            return model_mean, posterior_variance, posterior_log_variance, logits\n        elif return_x0:\n            return model_mean, posterior_variance, posterior_log_variance, x_recon\n        else:\n            return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, c, t, clip_denoised=False, 
repeat_noise=False,\n                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,\n                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n        b, *_, device = *x.shape, x.device\n        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,\n                                       return_codebook_ids=return_codebook_ids,\n                                       quantize_denoised=quantize_denoised,\n                                       return_x0=return_x0,\n                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n        if return_codebook_ids:\n            raise DeprecationWarning(\"Support dropped.\")\n            model_mean, _, model_log_variance, logits = outputs\n        elif return_x0:\n            model_mean, _, model_log_variance, x0 = outputs\n        else:\n            model_mean, _, model_log_variance = outputs\n\n        noise = noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n\n        if return_codebook_ids:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)\n        if return_x0:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0\n        else:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,\n                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,\n                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,\n        
                      log_every_t=None):\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        timesteps = self.num_timesteps\n        if batch_size is not None:\n            b = batch_size if batch_size is not None else shape[0]\n            shape = [batch_size] + list(shape)\n        else:\n            b = batch_size = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=self.device)\n        else:\n            img = x_T\n        intermediates = []\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                [x[:batch_size] for x in cond[key]] for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',\n                        total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n        if type(temperature) == float:\n            temperature = [temperature] * timesteps\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=self.device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img, x0_partial = self.p_sample(img, cond, ts,\n                                            clip_denoised=self.clip_denoised,\n                                            quantize_denoised=quantize_denoised, return_x0=True,\n                                            temperature=temperature[i], noise_dropout=noise_dropout,\n                                            
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. - mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(x0_partial)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_loop(self, cond, shape, return_intermediates=False,\n                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,\n                      mask=None, x0=None, img_callback=None, start_T=None,\n                      log_every_t=None):\n\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        device = self.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        intermediates = [img]\n        if timesteps is None:\n            timesteps = self.num_timesteps\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n\n        if mask is not None:\n            assert x0 is not None\n            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img = self.p_sample(img, cond, ts,\n                   
             clip_denoised=self.clip_denoised,\n                                quantize_denoised=quantize_denoised)\n            if mask is not None:\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. - mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(img)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,\n               verbose=True, timesteps=None, quantize_denoised=False,\n               mask=None, x0=None, shape=None,**kwargs):\n        if shape is None:\n            shape = (batch_size, self.channels, self.image_size, self.image_size)\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                [x[:batch_size] for x in cond[key]] for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n        return self.p_sample_loop(cond,\n                                  shape,\n                                  return_intermediates=return_intermediates, x_T=x_T,\n                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,\n                                  mask=mask, x0=x0)\n\n    @torch.no_grad()\n    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):\n\n        if ddim:\n            ddim_sampler = DDIMSampler(self)\n            shape = (self.channels, self.image_size, self.image_size)\n            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,\n                                                        
shape,cond,verbose=False,**kwargs)\n\n        else:\n            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,\n                                                 return_intermediates=True,**kwargs)\n\n        return samples, intermediates\n\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,\n                   plot_diffusion_rows=True, **kwargs):\n\n        use_ddim = ddim_steps is not None\n\n        log = {}\n        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,\n                                           return_first_stage_outputs=True,\n                                           force_c_encode=True,\n                                           return_original_cond=True,\n                                           bs=N)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reconstruction\"] = xrec\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"caption\"])\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"])\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if plot_diffusion_rows:\n            # get diffusion row\n            diffusion_row = []\n            z_start = z[:n_row]\n            
for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')\n            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])\n            log[\"diffusion_row\"] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with self.ema_scope(\"Plotting\"):\n                samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                         ddim_steps=ddim_steps,eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log[\"denoise_row\"] = denoise_grid\n\n            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(\n                    self.first_stage_model, IdentityFirstStage):\n                # also display when quantizing x0 while sampling\n                with self.ema_scope(\"Plotting Quantized Denoised\"):\n                    samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                             
ddim_steps=ddim_steps,eta=ddim_eta,\n                                                             quantize_denoised=True)\n                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,\n                    #                                      quantize_denoised=True)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_x0_quantized\"] = x_samples\n\n            if inpaint:\n                # make a simple center square\n                h, w = z.shape[2], z.shape[3]\n                mask = torch.ones(N, h, w).to(self.device)\n                # zeros will be filled in\n                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.\n                mask = mask[:, None, ...]\n                with self.ema_scope(\"Plotting Inpaint\"):\n\n                    samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,\n                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_inpainting\"] = x_samples\n                log[\"mask\"] = mask\n\n                # outpaint\n                with self.ema_scope(\"Plotting Outpaint\"):\n                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,\n                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_outpainting\"] = x_samples\n\n        if plot_progressive_rows:\n            with self.ema_scope(\"Plotting Progressives\"):\n                img, progressives = self.progressive_denoising(c,\n                                                               shape=(self.channels, self.image_size, self.image_size),\n                                                               batch_size=N)\n 
           prog_row = self._get_denoise_row_from_list(progressives, desc=\"Progressive Generation\")\n            log[\"progressive_row\"] = prog_row\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        if self.cond_stage_trainable:\n            print(f\"{self.__class__.__name__}: Also optimizing conditioner params!\")\n            params = params + list(self.cond_stage_model.parameters())\n        if self.learn_logvar:\n            print('Diffusion model optimizing logvar')\n            params.append(self.logvar)\n        opt = torch.optim.AdamW(params, lr=lr)\n        if self.use_scheduler:\n            assert 'target' in self.scheduler_config\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                }]\n            return [opt], scheduler\n        return opt\n\n    @torch.no_grad()\n    def to_rgb(self, x):\n        x = x.float()\n        if not hasattr(self, \"colorize\"):\n            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)\n        x = nn.functional.conv2d(x, weight=self.colorize)\n        x = 2. 
* (x - x.min()) / (x.max() - x.min()) - 1.\n        return x\n\n\nclass DiffusionWrapperV1(pl.LightningModule):\n    def __init__(self, diff_model_config, conditioning_key):\n        super().__init__()\n        self.diffusion_model = instantiate_from_config(diff_model_config)\n        self.conditioning_key = conditioning_key\n        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']\n\n    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):\n        if self.conditioning_key is None:\n            out = self.diffusion_model(x, t)\n        elif self.conditioning_key == 'concat':\n            xc = torch.cat([x] + c_concat, dim=1)\n            out = self.diffusion_model(xc, t)\n        elif self.conditioning_key == 'crossattn':\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(x, t, context=cc)\n        elif self.conditioning_key == 'hybrid':\n            xc = torch.cat([x] + c_concat, dim=1)\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(xc, t, context=cc)\n        elif self.conditioning_key == 'adm':\n            cc = c_crossattn[0]\n            out = self.diffusion_model(x, t, y=cc)\n        else:\n            raise NotImplementedError()\n\n        return out\n\n\nclass Layout2ImgDiffusionV1(LatentDiffusionV1):\n    # TODO: move all layout-specific hacks to this class\n    def __init__(self, cond_stage_key, *args, **kwargs):\n        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key=\"coordinates_bbox\"'\n        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)\n\n    def log_images(self, batch, N=8, *args, **kwargs):\n        logs = super().log_images(*args, batch=batch, N=N, **kwargs)\n\n        key = 'train' if self.training else 'validation'\n        dset = self.trainer.datamodule.datasets[key]\n        mapper = dset.conditional_builders[self.cond_stage_key]\n\n        bbox_imgs = []\n    
    map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))\n        for tknzd_bbox in batch[self.cond_stage_key][:N]:\n            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))\n            bbox_imgs.append(bboximg)\n\n        cond_img = torch.stack(bbox_imgs, dim=0)\n        logs['bbox_image'] = cond_img\n        return logs\n\nldm.models.diffusion.ddpm.DDPMV1 = DDPMV1\nldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1\nldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1\nldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1\n"
  },
  {
    "path": "extensions-builtin/LDSR/vqvae_quantize.py",
    "content": "# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py,\n# where the license is as follows:\n#\n# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\n# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\n# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.\n# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,\n# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR\n# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE\n# OR OTHER DEALINGS IN THE SOFTWARE./\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\n\n\nclass VectorQuantizer2(nn.Module):\n    \"\"\"\n    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly\n    avoids costly matrix multiplications and allows for post-hoc remapping of indices.\n    \"\"\"\n\n    # NOTE: due to a bug the beta term was applied to the wrong term. 
for\n    # backwards compatibility we use the buggy version by default, but you can\n    # specify legacy=False to fix it.\n    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index=\"random\",\n                 sane_index_shape=False, legacy=True):\n        super().__init__()\n        self.n_e = n_e\n        self.e_dim = e_dim\n        self.beta = beta\n        self.legacy = legacy\n\n        self.embedding = nn.Embedding(self.n_e, self.e_dim)\n        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)\n\n        self.remap = remap\n        if self.remap is not None:\n            self.register_buffer(\"used\", torch.tensor(np.load(self.remap)))\n            self.re_embed = self.used.shape[0]\n            self.unknown_index = unknown_index  # \"random\" or \"extra\" or integer\n            if self.unknown_index == \"extra\":\n                self.unknown_index = self.re_embed\n                self.re_embed = self.re_embed + 1\n            print(f\"Remapping {self.n_e} indices to {self.re_embed} indices. 
\"\n                  f\"Using {self.unknown_index} for unknown indices.\")\n        else:\n            self.re_embed = n_e\n\n        self.sane_index_shape = sane_index_shape\n\n    def remap_to_used(self, inds):\n        ishape = inds.shape\n        assert len(ishape) > 1\n        inds = inds.reshape(ishape[0], -1)\n        used = self.used.to(inds)\n        match = (inds[:, :, None] == used[None, None, ...]).long()\n        new = match.argmax(-1)\n        unknown = match.sum(2) < 1\n        if self.unknown_index == \"random\":\n            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)\n        else:\n            new[unknown] = self.unknown_index\n        return new.reshape(ishape)\n\n    def unmap_to_all(self, inds):\n        ishape = inds.shape\n        assert len(ishape) > 1\n        inds = inds.reshape(ishape[0], -1)\n        used = self.used.to(inds)\n        if self.re_embed > self.used.shape[0]:  # extra token\n            inds[inds >= self.used.shape[0]] = 0  # simply set to zero\n        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)\n        return back.reshape(ishape)\n\n    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):\n        assert temp is None or temp == 1.0, \"Only for interface compatible with Gumbel\"\n        assert rescale_logits is False, \"Only for interface compatible with Gumbel\"\n        assert return_logits is False, \"Only for interface compatible with Gumbel\"\n        # reshape z -> (batch, height, width, channel) and flatten\n        z = rearrange(z, 'b c h w -> b h w c').contiguous()\n        z_flattened = z.view(-1, self.e_dim)\n        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z\n\n        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \\\n            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \\\n            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d 
n'))\n\n        min_encoding_indices = torch.argmin(d, dim=1)\n        z_q = self.embedding(min_encoding_indices).view(z.shape)\n        perplexity = None\n        min_encodings = None\n\n        # compute loss for embedding\n        if not self.legacy:\n            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \\\n                   torch.mean((z_q - z.detach()) ** 2)\n        else:\n            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \\\n                   torch.mean((z_q - z.detach()) ** 2)\n\n        # preserve gradients\n        z_q = z + (z_q - z).detach()\n\n        # reshape back to match original input shape\n        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()\n\n        if self.remap is not None:\n            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis\n            min_encoding_indices = self.remap_to_used(min_encoding_indices)\n            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten\n\n        if self.sane_index_shape:\n            min_encoding_indices = min_encoding_indices.reshape(\n                z_q.shape[0], z_q.shape[2], z_q.shape[3])\n\n        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)\n\n    def get_codebook_entry(self, indices, shape):\n        # shape specifying (batch, height, width, channel)\n        if self.remap is not None:\n            indices = indices.reshape(shape[0], -1)  # add batch axis\n            indices = self.unmap_to_all(indices)\n            indices = indices.reshape(-1)  # flatten again\n\n        # get quantized latent vectors\n        z_q = self.embedding(indices)\n\n        if shape is not None:\n            z_q = z_q.view(shape)\n            # reshape back to match original input shape\n            z_q = z_q.permute(0, 3, 1, 2).contiguous()\n\n        return z_q\n"
  },
  {
    "path": "extensions-builtin/Lora/extra_networks_lora.py",
    "content": "from modules import extra_networks, shared\r\nimport networks\r\n\r\n\r\nclass ExtraNetworkLora(extra_networks.ExtraNetwork):\r\n    def __init__(self):\r\n        super().__init__('lora')\r\n\r\n        self.errors = {}\r\n        \"\"\"mapping of network names to the number of errors the network had during operation\"\"\"\r\n\r\n    remove_symbols = str.maketrans('', '', \":,\")\r\n\r\n    def activate(self, p, params_list):\r\n        additional = shared.opts.sd_lora\r\n\r\n        self.errors.clear()\r\n\r\n        if additional != \"None\" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):\r\n            p.all_prompts = [x + f\"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>\" for x in p.all_prompts]\r\n            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))\r\n\r\n        names = []\r\n        te_multipliers = []\r\n        unet_multipliers = []\r\n        dyn_dims = []\r\n        for params in params_list:\r\n            assert params.items\r\n\r\n            names.append(params.positional[0])\r\n\r\n            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0\r\n            te_multiplier = float(params.named.get(\"te\", te_multiplier))\r\n\r\n            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier\r\n            unet_multiplier = float(params.named.get(\"unet\", unet_multiplier))\r\n\r\n            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None\r\n            dyn_dim = int(params.named[\"dyn\"]) if \"dyn\" in params.named else dyn_dim\r\n\r\n            te_multipliers.append(te_multiplier)\r\n            unet_multipliers.append(unet_multiplier)\r\n            dyn_dims.append(dyn_dim)\r\n\r\n        networks.load_networks(names, te_multipliers, unet_multipliers, 
dyn_dims)\r\n\r\n        if shared.opts.lora_add_hashes_to_infotext:\r\n            if not getattr(p, \"is_hr_pass\", False) or not hasattr(p, \"lora_hashes\"):\r\n                p.lora_hashes = {}\r\n\r\n            for item in networks.loaded_networks:\r\n                if item.network_on_disk.shorthash and item.mentioned_name:\r\n                    p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash\r\n\r\n            if p.lora_hashes:\r\n                p.extra_generation_params[\"Lora hashes\"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items())\r\n\r\n    def deactivate(self, p):\r\n        if self.errors:\r\n            p.comment(\"Networks with errors: \" + \", \".join(f\"{k} ({v})\" for k, v in self.errors.items()))\r\n\r\n            self.errors.clear()\r\n"
  },
  {
    "path": "extensions-builtin/Lora/lora.py",
    "content": "import networks\r\n\r\nlist_available_loras = networks.list_available_networks\r\n\r\navailable_loras = networks.available_networks\r\navailable_lora_aliases = networks.available_network_aliases\r\navailable_lora_hash_lookup = networks.available_network_hash_lookup\r\nforbidden_lora_aliases = networks.forbidden_network_aliases\r\nloaded_loras = networks.loaded_networks\r\n"
  },
  {
    "path": "extensions-builtin/Lora/lora_logger.py",
    "content": "import sys\nimport copy\nimport logging\n\n\nclass ColoredFormatter(logging.Formatter):\n    COLORS = {\n        \"DEBUG\": \"\\033[0;36m\",  # CYAN\n        \"INFO\": \"\\033[0;32m\",  # GREEN\n        \"WARNING\": \"\\033[0;33m\",  # YELLOW\n        \"ERROR\": \"\\033[0;31m\",  # RED\n        \"CRITICAL\": \"\\033[0;37;41m\",  # WHITE ON RED\n        \"RESET\": \"\\033[0m\",  # RESET COLOR\n    }\n\n    def format(self, record):\n        colored_record = copy.copy(record)\n        levelname = colored_record.levelname\n        seq = self.COLORS.get(levelname, self.COLORS[\"RESET\"])\n        colored_record.levelname = f\"{seq}{levelname}{self.COLORS['RESET']}\"\n        return super().format(colored_record)\n\n\nlogger = logging.getLogger(\"lora\")\nlogger.propagate = False\n\n\nif not logger.handlers:\n    handler = logging.StreamHandler(sys.stdout)\n    handler.setFormatter(\n        ColoredFormatter(\"[%(name)s]-%(levelname)s: %(message)s\")\n    )\n    logger.addHandler(handler)\n"
  },
  {
    "path": "extensions-builtin/Lora/lora_patches.py",
    "content": "import torch\r\n\r\nimport networks\r\nfrom modules import patches\r\n\r\n\r\nclass LoraPatches:\r\n    def __init__(self):\r\n        self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)\r\n        self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)\r\n        self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)\r\n        self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)\r\n        self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)\r\n        self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)\r\n        self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)\r\n        self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)\r\n        self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)\r\n        self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)\r\n\r\n    def undo(self):\r\n        self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')\r\n        self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')\r\n        self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')\r\n        self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')\r\n        
self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')\r\n        self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')\r\n        self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')\r\n        self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')\r\n        self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')\r\n        self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')\r\n\r\n"
  },
  {
    "path": "extensions-builtin/Lora/lyco_helpers.py",
    "content": "import torch\r\n\r\n\r\ndef make_weight_cp(t, wa, wb):\r\n    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)\r\n    return torch.einsum('i j k l, i r -> r j k l', temp, wa)\r\n\r\n\r\ndef rebuild_conventional(up, down, shape, dyn_dim=None):\r\n    up = up.reshape(up.size(0), -1)\r\n    down = down.reshape(down.size(0), -1)\r\n    if dyn_dim is not None:\r\n        up = up[:, :dyn_dim]\r\n        down = down[:dyn_dim, :]\r\n    return (up @ down).reshape(shape)\r\n\r\n\r\ndef rebuild_cp_decomposition(up, down, mid):\r\n    up = up.reshape(up.size(0), -1)\r\n    down = down.reshape(down.size(0), -1)\r\n    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)\r\n\r\n\r\n# copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py\r\ndef factorization(dimension: int, factor:int=-1) -> tuple[int, int]:\r\n    '''\r\n    return a tuple of two value of input dimension decomposed by the number closest to factor\r\n    second value is higher or equal than first value.\r\n\r\n    In LoRA with Kroneckor Product, first value is a value for weight scale.\r\n    secon value is a value for weight.\r\n\r\n    Because of non-commutative property, A⊗B ≠ B⊗A. 
Meaning of two matrices is slightly different.\r\n\r\n    examples)\r\n    factor\r\n        -1               2                4               8               16               ...\r\n    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127\r\n    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16\r\n    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25\r\n    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30\r\n    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32\r\n    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64\r\n    '''\r\n\r\n    if factor > 0 and (dimension % factor) == 0:\r\n        m = factor\r\n        n = dimension // factor\r\n        if m > n:\r\n            n, m = m, n\r\n        return m, n\r\n    if factor < 0:\r\n        factor = dimension\r\n    m, n = 1, dimension\r\n    length = m + n\r\n    while m<n:\r\n        new_m = m + 1\r\n        while dimension%new_m != 0:\r\n            new_m += 1\r\n        new_n = dimension // new_m\r\n        if new_m + new_n > length or new_m>factor:\r\n            break\r\n        else:\r\n            m, n = new_m, new_n\r\n    if m > n:\r\n        n, m = m, n\r\n    return m, n\r\n\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network.py",
    "content": "from __future__ import annotations\r\nimport os\r\nfrom collections import namedtuple\r\nimport enum\r\n\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom modules import sd_models, cache, errors, hashes, shared\r\nimport modules.models.sd3.mmdit\r\n\r\nNetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])\r\n\r\nmetadata_tags_order = {\"ss_sd_model_name\": 1, \"ss_resolution\": 2, \"ss_clip_skip\": 3, \"ss_num_train_images\": 10, \"ss_tag_frequency\": 20}\r\n\r\n\r\nclass SdVersion(enum.Enum):\r\n    Unknown = 1\r\n    SD1 = 2\r\n    SD2 = 3\r\n    SDXL = 4\r\n\r\n\r\nclass NetworkOnDisk:\r\n    def __init__(self, name, filename):\r\n        self.name = name\r\n        self.filename = filename\r\n        self.metadata = {}\r\n        self.is_safetensors = os.path.splitext(filename)[1].lower() == \".safetensors\"\r\n\r\n        def read_metadata():\r\n            metadata = sd_models.read_metadata_from_safetensors(filename)\r\n\r\n            return metadata\r\n\r\n        if self.is_safetensors:\r\n            try:\r\n                self.metadata = cache.cached_data_for_file('safetensors-metadata', \"lora/\" + self.name, filename, read_metadata)\r\n            except Exception as e:\r\n                errors.display(e, f\"reading lora {filename}\")\r\n\r\n        if self.metadata:\r\n            m = {}\r\n            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):\r\n                m[k] = v\r\n\r\n            self.metadata = m\r\n\r\n        self.alias = self.metadata.get('ss_output_name', self.name)\r\n\r\n        self.hash = None\r\n        self.shorthash = None\r\n        self.set_hash(\r\n            self.metadata.get('sshs_model_hash') or\r\n            hashes.sha256_from_cache(self.filename, \"lora/\" + self.name, use_addnet_hash=self.is_safetensors) or\r\n            ''\r\n        )\r\n\r\n        self.sd_version = 
self.detect_version()\r\n\r\n    def detect_version(self):\r\n        if str(self.metadata.get('ss_base_model_version', \"\")).startswith(\"sdxl_\"):\r\n            return SdVersion.SDXL\r\n        elif str(self.metadata.get('ss_v2', \"\")) == \"True\":\r\n            return SdVersion.SD2\r\n        elif len(self.metadata):\r\n            return SdVersion.SD1\r\n\r\n        return SdVersion.Unknown\r\n\r\n    def set_hash(self, v):\r\n        self.hash = v\r\n        self.shorthash = self.hash[0:12]\r\n\r\n        if self.shorthash:\r\n            import networks\r\n            networks.available_network_hash_lookup[self.shorthash] = self\r\n\r\n    def read_hash(self):\r\n        if not self.hash:\r\n            self.set_hash(hashes.sha256(self.filename, \"lora/\" + self.name, use_addnet_hash=self.is_safetensors) or '')\r\n\r\n    def get_alias(self):\r\n        import networks\r\n        if shared.opts.lora_preferred_name == \"Filename\" or self.alias.lower() in networks.forbidden_network_aliases:\r\n            return self.name\r\n        else:\r\n            return self.alias\r\n\r\n\r\nclass Network:  # LoraModule\r\n    def __init__(self, name, network_on_disk: NetworkOnDisk):\r\n        self.name = name\r\n        self.network_on_disk = network_on_disk\r\n        self.te_multiplier = 1.0\r\n        self.unet_multiplier = 1.0\r\n        self.dyn_dim = None\r\n        self.modules = {}\r\n        self.bundle_embeddings = {}\r\n        self.mtime = None\r\n\r\n        self.mentioned_name = None\r\n        \"\"\"the text that was used to add the network to prompt - can be either name or an alias\"\"\"\r\n\r\n\r\nclass ModuleType:\r\n    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:\r\n        return None\r\n\r\n\r\nclass NetworkModule:\r\n    def __init__(self, net: Network, weights: NetworkWeights):\r\n        self.network = net\r\n        self.network_key = weights.network_key\r\n        self.sd_key = weights.sd_key\r\n      
  self.sd_module = weights.sd_module\r\n\r\n        if isinstance(self.sd_module, modules.models.sd3.mmdit.QkvLinear):\r\n            s = self.sd_module.weight.shape\r\n            self.shape = (s[0] // 3, s[1])\r\n        elif hasattr(self.sd_module, 'weight'):\r\n            self.shape = self.sd_module.weight.shape\r\n        elif isinstance(self.sd_module, nn.MultiheadAttention):\r\n            # For now, only self-attn use Pytorch's MHA\r\n            # So assume all qkvo proj have same shape\r\n            self.shape = self.sd_module.out_proj.weight.shape\r\n        else:\r\n            self.shape = None\r\n\r\n        self.ops = None\r\n        self.extra_kwargs = {}\r\n        if isinstance(self.sd_module, nn.Conv2d):\r\n            self.ops = F.conv2d\r\n            self.extra_kwargs = {\r\n                'stride': self.sd_module.stride,\r\n                'padding': self.sd_module.padding\r\n            }\r\n        elif isinstance(self.sd_module, nn.Linear):\r\n            self.ops = F.linear\r\n        elif isinstance(self.sd_module, nn.LayerNorm):\r\n            self.ops = F.layer_norm\r\n            self.extra_kwargs = {\r\n                'normalized_shape': self.sd_module.normalized_shape,\r\n                'eps': self.sd_module.eps\r\n            }\r\n        elif isinstance(self.sd_module, nn.GroupNorm):\r\n            self.ops = F.group_norm\r\n            self.extra_kwargs = {\r\n                'num_groups': self.sd_module.num_groups,\r\n                'eps': self.sd_module.eps\r\n            }\r\n\r\n        self.dim = None\r\n        self.bias = weights.w.get(\"bias\")\r\n        self.alpha = weights.w[\"alpha\"].item() if \"alpha\" in weights.w else None\r\n        self.scale = weights.w[\"scale\"].item() if \"scale\" in weights.w else None\r\n\r\n        self.dora_scale = weights.w.get(\"dora_scale\", None)\r\n        self.dora_norm_dims = len(self.shape) - 1\r\n\r\n    def multiplier(self):\r\n        if 'transformer' in 
self.sd_key[:20]:\r\n            return self.network.te_multiplier\r\n        else:\r\n            return self.network.unet_multiplier\r\n\r\n    def calc_scale(self):\r\n        if self.scale is not None:\r\n            return self.scale\r\n        if self.dim is not None and self.alpha is not None:\r\n            return self.alpha / self.dim\r\n\r\n        return 1.0\r\n\r\n    def apply_weight_decompose(self, updown, orig_weight):\r\n        # Match the device/dtype\r\n        orig_weight = orig_weight.to(updown.dtype)\r\n        dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)\r\n        updown = updown.to(orig_weight.device)\r\n\r\n        merged_scale1 = updown + orig_weight\r\n        merged_scale1_norm = (\r\n            merged_scale1.transpose(0, 1)\r\n            .reshape(merged_scale1.shape[1], -1)\r\n            .norm(dim=1, keepdim=True)\r\n            .reshape(merged_scale1.shape[1], *[1] * self.dora_norm_dims)\r\n            .transpose(0, 1)\r\n        )\r\n\r\n        dora_merged = (\r\n            merged_scale1 * (dora_scale / merged_scale1_norm)\r\n        )\r\n        final_updown = dora_merged - orig_weight\r\n        return final_updown\r\n\r\n    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):\r\n        if self.bias is not None:\r\n            updown = updown.reshape(self.bias.shape)\r\n            updown += self.bias.to(orig_weight.device, dtype=updown.dtype)\r\n            updown = updown.reshape(output_shape)\r\n\r\n        if len(output_shape) == 4:\r\n            updown = updown.reshape(output_shape)\r\n\r\n        if orig_weight.size().numel() == updown.size().numel():\r\n            updown = updown.reshape(orig_weight.shape)\r\n\r\n        if ex_bias is not None:\r\n            ex_bias = ex_bias * self.multiplier()\r\n\r\n        updown = updown * self.calc_scale()\r\n\r\n        if self.dora_scale is not None:\r\n            updown = self.apply_weight_decompose(updown, 
orig_weight)\r\n\r\n        return updown * self.multiplier(), ex_bias\r\n\r\n    def calc_updown(self, target):\r\n        raise NotImplementedError()\r\n\r\n    def forward(self, x, y):\r\n        \"\"\"A general forward implementation for all modules\"\"\"\r\n        if self.ops is None:\r\n            raise NotImplementedError()\r\n        else:\r\n            updown, ex_bias = self.calc_updown(self.sd_module.weight)\r\n            return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)\r\n\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_full.py",
    "content": "import network\r\n\r\n\r\nclass ModuleTypeFull(network.ModuleType):\r\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\r\n        if all(x in weights.w for x in [\"diff\"]):\r\n            return NetworkModuleFull(net, weights)\r\n\r\n        return None\r\n\r\n\r\nclass NetworkModuleFull(network.NetworkModule):\r\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\r\n        super().__init__(net, weights)\r\n\r\n        self.weight = weights.w.get(\"diff\")\r\n        self.ex_bias = weights.w.get(\"diff_b\")\r\n\r\n    def calc_updown(self, orig_weight):\r\n        output_shape = self.weight.shape\r\n        updown = self.weight.to(orig_weight.device)\r\n        if self.ex_bias is not None:\r\n            ex_bias = self.ex_bias.to(orig_weight.device)\r\n        else:\r\n            ex_bias = None\r\n\r\n        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_glora.py",
    "content": "\nimport network\n\nclass ModuleTypeGLora(network.ModuleType):\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\n        if all(x in weights.w for x in [\"a1.weight\", \"a2.weight\", \"alpha\", \"b1.weight\", \"b2.weight\"]):\n            return NetworkModuleGLora(net, weights)\n\n        return None\n\n# adapted from https://github.com/KohakuBlueleaf/LyCORIS\nclass NetworkModuleGLora(network.NetworkModule):\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\n        super().__init__(net, weights)\n\n        if hasattr(self.sd_module, 'weight'):\n            self.shape = self.sd_module.weight.shape\n\n        self.w1a = weights.w[\"a1.weight\"]\n        self.w1b = weights.w[\"b1.weight\"]\n        self.w2a = weights.w[\"a2.weight\"]\n        self.w2b = weights.w[\"b2.weight\"]\n\n    def calc_updown(self, orig_weight):\n        w1a = self.w1a.to(orig_weight.device)\n        w1b = self.w1b.to(orig_weight.device)\n        w2a = self.w2a.to(orig_weight.device)\n        w2b = self.w2b.to(orig_weight.device)\n\n        output_shape = [w1a.size(0), w1b.size(1)]\n        updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a))\n\n        return self.finalize_updown(updown, orig_weight, output_shape)\n"
  },
  {
    "path": "extensions-builtin/Lora/network_hada.py",
    "content": "import lyco_helpers\r\nimport network\r\n\r\n\r\nclass ModuleTypeHada(network.ModuleType):\r\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\r\n        if all(x in weights.w for x in [\"hada_w1_a\", \"hada_w1_b\", \"hada_w2_a\", \"hada_w2_b\"]):\r\n            return NetworkModuleHada(net, weights)\r\n\r\n        return None\r\n\r\n\r\nclass NetworkModuleHada(network.NetworkModule):\r\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\r\n        super().__init__(net, weights)\r\n\r\n        if hasattr(self.sd_module, 'weight'):\r\n            self.shape = self.sd_module.weight.shape\r\n\r\n        self.w1a = weights.w[\"hada_w1_a\"]\r\n        self.w1b = weights.w[\"hada_w1_b\"]\r\n        self.dim = self.w1b.shape[0]\r\n        self.w2a = weights.w[\"hada_w2_a\"]\r\n        self.w2b = weights.w[\"hada_w2_b\"]\r\n\r\n        self.t1 = weights.w.get(\"hada_t1\")\r\n        self.t2 = weights.w.get(\"hada_t2\")\r\n\r\n    def calc_updown(self, orig_weight):\r\n        w1a = self.w1a.to(orig_weight.device)\r\n        w1b = self.w1b.to(orig_weight.device)\r\n        w2a = self.w2a.to(orig_weight.device)\r\n        w2b = self.w2b.to(orig_weight.device)\r\n\r\n        output_shape = [w1a.size(0), w1b.size(1)]\r\n\r\n        if self.t1 is not None:\r\n            output_shape = [w1a.size(1), w1b.size(1)]\r\n            t1 = self.t1.to(orig_weight.device)\r\n            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)\r\n            output_shape += t1.shape[2:]\r\n        else:\r\n            if len(w1b.shape) == 4:\r\n                output_shape += w1b.shape[2:]\r\n            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)\r\n\r\n        if self.t2 is not None:\r\n            t2 = self.t2.to(orig_weight.device)\r\n            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)\r\n        else:\r\n            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, 
output_shape)\r\n\r\n        updown = updown1 * updown2\r\n\r\n        return self.finalize_updown(updown, orig_weight, output_shape)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_ia3.py",
    "content": "import network\r\n\r\n\r\nclass ModuleTypeIa3(network.ModuleType):\r\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\r\n        if all(x in weights.w for x in [\"weight\"]):\r\n            return NetworkModuleIa3(net, weights)\r\n\r\n        return None\r\n\r\n\r\nclass NetworkModuleIa3(network.NetworkModule):\r\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\r\n        super().__init__(net, weights)\r\n\r\n        self.w = weights.w[\"weight\"]\r\n        self.on_input = weights.w[\"on_input\"].item()\r\n\r\n    def calc_updown(self, orig_weight):\r\n        w = self.w.to(orig_weight.device)\r\n\r\n        output_shape = [w.size(0), orig_weight.size(1)]\r\n        if self.on_input:\r\n            output_shape.reverse()\r\n        else:\r\n            w = w.reshape(-1, 1)\r\n\r\n        updown = orig_weight * w\r\n\r\n        return self.finalize_updown(updown, orig_weight, output_shape)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_lokr.py",
    "content": "import torch\r\n\r\nimport lyco_helpers\r\nimport network\r\n\r\n\r\nclass ModuleTypeLokr(network.ModuleType):\r\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\r\n        has_1 = \"lokr_w1\" in weights.w or (\"lokr_w1_a\" in weights.w and \"lokr_w1_b\" in weights.w)\r\n        has_2 = \"lokr_w2\" in weights.w or (\"lokr_w2_a\" in weights.w and \"lokr_w2_b\" in weights.w)\r\n        if has_1 and has_2:\r\n            return NetworkModuleLokr(net, weights)\r\n\r\n        return None\r\n\r\n\r\ndef make_kron(orig_shape, w1, w2):\r\n    if len(w2.shape) == 4:\r\n        w1 = w1.unsqueeze(2).unsqueeze(2)\r\n    w2 = w2.contiguous()\r\n    return torch.kron(w1, w2).reshape(orig_shape)\r\n\r\n\r\nclass NetworkModuleLokr(network.NetworkModule):\r\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\r\n        super().__init__(net, weights)\r\n\r\n        self.w1 = weights.w.get(\"lokr_w1\")\r\n        self.w1a = weights.w.get(\"lokr_w1_a\")\r\n        self.w1b = weights.w.get(\"lokr_w1_b\")\r\n        self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim\r\n        self.w2 = weights.w.get(\"lokr_w2\")\r\n        self.w2a = weights.w.get(\"lokr_w2_a\")\r\n        self.w2b = weights.w.get(\"lokr_w2_b\")\r\n        self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim\r\n        self.t2 = weights.w.get(\"lokr_t2\")\r\n\r\n    def calc_updown(self, orig_weight):\r\n        if self.w1 is not None:\r\n            w1 = self.w1.to(orig_weight.device)\r\n        else:\r\n            w1a = self.w1a.to(orig_weight.device)\r\n            w1b = self.w1b.to(orig_weight.device)\r\n            w1 = w1a @ w1b\r\n\r\n        if self.w2 is not None:\r\n            w2 = self.w2.to(orig_weight.device)\r\n        elif self.t2 is None:\r\n            w2a = self.w2a.to(orig_weight.device)\r\n            w2b = self.w2b.to(orig_weight.device)\r\n            w2 = w2a @ w2b\r\n        else:\r\n 
           t2 = self.t2.to(orig_weight.device)\r\n            w2a = self.w2a.to(orig_weight.device)\r\n            w2b = self.w2b.to(orig_weight.device)\r\n            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)\r\n\r\n        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]\r\n        if len(orig_weight.shape) == 4:\r\n            output_shape = orig_weight.shape\r\n\r\n        updown = make_kron(output_shape, w1, w2)\r\n\r\n        return self.finalize_updown(updown, orig_weight, output_shape)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_lora.py",
    "content": "import torch\r\n\r\nimport lyco_helpers\r\nimport modules.models.sd3.mmdit\r\nimport network\r\nfrom modules import devices\r\n\r\n\r\nclass ModuleTypeLora(network.ModuleType):\r\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\r\n        if all(x in weights.w for x in [\"lora_up.weight\", \"lora_down.weight\"]):\r\n            return NetworkModuleLora(net, weights)\r\n\r\n        if all(x in weights.w for x in [\"lora_A.weight\", \"lora_B.weight\"]):\r\n            w = weights.w.copy()\r\n            weights.w.clear()\r\n            weights.w.update({\"lora_up.weight\": w[\"lora_B.weight\"], \"lora_down.weight\": w[\"lora_A.weight\"]})\r\n\r\n            return NetworkModuleLora(net, weights)\r\n\r\n        return None\r\n\r\n\r\nclass NetworkModuleLora(network.NetworkModule):\r\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\r\n        super().__init__(net, weights)\r\n\r\n        self.up_model = self.create_module(weights.w, \"lora_up.weight\")\r\n        self.down_model = self.create_module(weights.w, \"lora_down.weight\")\r\n        self.mid_model = self.create_module(weights.w, \"lora_mid.weight\", none_ok=True)\r\n\r\n        self.dim = weights.w[\"lora_down.weight\"].shape[0]\r\n\r\n    def create_module(self, weights, key, none_ok=False):\r\n        weight = weights.get(key)\r\n\r\n        if weight is None and none_ok:\r\n            return None\r\n\r\n        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, modules.models.sd3.mmdit.QkvLinear]\r\n        is_conv = type(self.sd_module) in [torch.nn.Conv2d]\r\n\r\n        if is_linear:\r\n            weight = weight.reshape(weight.shape[0], -1)\r\n            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)\r\n        elif is_conv and key == \"lora_down.weight\" or key == \"dyn_up\":\r\n            if 
len(weight.shape) == 2:\r\n                weight = weight.reshape(weight.shape[0], -1, 1, 1)\r\n\r\n            if weight.shape[2] != 1 or weight.shape[3] != 1:\r\n                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)\r\n            else:\r\n                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)\r\n        elif is_conv and key == \"lora_mid.weight\":\r\n            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)\r\n        elif is_conv and key == \"lora_up.weight\" or key == \"dyn_down\":\r\n            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)\r\n        else:\r\n            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')\r\n\r\n        with torch.no_grad():\r\n            if weight.shape != module.weight.shape:\r\n                weight = weight.reshape(module.weight.shape)\r\n            module.weight.copy_(weight)\r\n\r\n        module.to(device=devices.cpu, dtype=devices.dtype)\r\n        module.weight.requires_grad_(False)\r\n\r\n        return module\r\n\r\n    def calc_updown(self, orig_weight):\r\n        up = self.up_model.weight.to(orig_weight.device)\r\n        down = self.down_model.weight.to(orig_weight.device)\r\n\r\n        output_shape = [up.size(0), down.size(1)]\r\n        if self.mid_model is not None:\r\n            # cp-decomposition\r\n            mid = self.mid_model.weight.to(orig_weight.device)\r\n            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)\r\n            output_shape += mid.shape[2:]\r\n        else:\r\n            if len(down.shape) == 4:\r\n                output_shape += down.shape[2:]\r\n            updown = lyco_helpers.rebuild_conventional(up, 
down, output_shape, self.network.dyn_dim)\r\n\r\n        return self.finalize_updown(updown, orig_weight, output_shape)\r\n\r\n    def forward(self, x, y):\r\n        self.up_model.to(device=devices.device)\r\n        self.down_model.to(device=devices.device)\r\n\r\n        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()\r\n\r\n\r\n"
  },
  {
    "path": "extensions-builtin/Lora/network_norm.py",
    "content": "import network\n\n\nclass ModuleTypeNorm(network.ModuleType):\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\n        if all(x in weights.w for x in [\"w_norm\", \"b_norm\"]):\n            return NetworkModuleNorm(net, weights)\n\n        return None\n\n\nclass NetworkModuleNorm(network.NetworkModule):\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\n        super().__init__(net, weights)\n\n        self.w_norm = weights.w.get(\"w_norm\")\n        self.b_norm = weights.w.get(\"b_norm\")\n\n    def calc_updown(self, orig_weight):\n        output_shape = self.w_norm.shape\n        updown = self.w_norm.to(orig_weight.device)\n\n        if self.b_norm is not None:\n            ex_bias = self.b_norm.to(orig_weight.device)\n        else:\n            ex_bias = None\n\n        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)\n"
  },
  {
    "path": "extensions-builtin/Lora/network_oft.py",
    "content": "import torch\nimport network\nfrom einops import rearrange\n\n\nclass ModuleTypeOFT(network.ModuleType):\n    def create_module(self, net: network.Network, weights: network.NetworkWeights):\n        if all(x in weights.w for x in [\"oft_blocks\"]) or all(x in weights.w for x in [\"oft_diag\"]):\n            return NetworkModuleOFT(net, weights)\n\n        return None\n\n# Supports both kohya-ss' implementation of COFT  https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py\n# and KohakuBlueleaf's implementation of OFT/COFT https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/diag_oft.py\nclass NetworkModuleOFT(network.NetworkModule):\n    def __init__(self,  net: network.Network, weights: network.NetworkWeights):\n\n        super().__init__(net, weights)\n\n        self.lin_module = None\n        self.org_module: list[torch.Module] = [self.sd_module]\n\n        self.scale = 1.0\n        self.is_R = False\n        self.is_boft = False\n\n        # kohya-ss/New LyCORIS OFT/BOFT\n        if \"oft_blocks\" in weights.w.keys():\n            self.oft_blocks = weights.w[\"oft_blocks\"] # (num_blocks, block_size, block_size)\n            self.alpha = weights.w.get(\"alpha\", None) # alpha is constraint\n            self.dim = self.oft_blocks.shape[0] # lora dim\n        # Old LyCORIS OFT\n        elif \"oft_diag\" in weights.w.keys():\n            self.is_R = True\n            self.oft_blocks = weights.w[\"oft_diag\"]\n            # self.alpha is unused\n            self.dim = self.oft_blocks.shape[1] # (num_blocks, block_size, block_size)\n\n        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]\n        is_conv = type(self.sd_module) in [torch.nn.Conv2d]\n        is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention] # unsupported\n\n        if is_linear:\n            self.out_dim = self.sd_module.out_features\n        elif is_conv:\n            
self.out_dim = self.sd_module.out_channels\n        elif is_other_linear:\n            self.out_dim = self.sd_module.embed_dim\n\n        # LyCORIS BOFT\n        if self.oft_blocks.dim() == 4:\n            self.is_boft = True\n        self.rescale = weights.w.get('rescale', None)\n        if self.rescale is not None and not is_other_linear:\n            self.rescale = self.rescale.reshape(-1, *[1]*(self.org_module[0].weight.dim() - 1))\n\n        self.num_blocks = self.dim\n        self.block_size = self.out_dim // self.dim\n        self.constraint = (0 if self.alpha is None else self.alpha) * self.out_dim\n        if self.is_R:\n            self.constraint = None\n            self.block_size = self.dim\n            self.num_blocks = self.out_dim // self.dim\n        elif self.is_boft:\n            self.boft_m = self.oft_blocks.shape[0]\n            self.num_blocks = self.oft_blocks.shape[1]\n            self.block_size = self.oft_blocks.shape[2]\n            self.boft_b = self.block_size\n\n    def calc_updown(self, orig_weight):\n        oft_blocks = self.oft_blocks.to(orig_weight.device)\n        eye = torch.eye(self.block_size, device=oft_blocks.device)\n\n        if not self.is_R:\n            block_Q = oft_blocks - oft_blocks.transpose(-1, -2) # ensure skew-symmetric orthogonal matrix\n            if self.constraint != 0:\n                norm_Q = torch.norm(block_Q.flatten())\n                new_norm_Q = torch.clamp(norm_Q, max=self.constraint.to(oft_blocks.device))\n                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))\n            oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())\n\n        R = oft_blocks.to(orig_weight.device)\n\n        if not self.is_boft:\n            # This errors out for MultiheadAttention, might need to be handled up-stream\n            merged_weight = rearrange(orig_weight, '(k n) ... 
-> k n ...', k=self.num_blocks, n=self.block_size)\n            merged_weight = torch.einsum(\n                'k n m, k n ... -> k m ...',\n                R,\n                merged_weight\n            )\n            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')\n        else:\n            # TODO: determine correct value for scale\n            scale = 1.0\n            m = self.boft_m\n            b = self.boft_b\n            r_b = b // 2\n            inp = orig_weight\n            for i in range(m):\n                bi = R[i] # b_num, b_size, b_size\n                if i == 0:\n                    # Apply multiplier/scale and rescale into first weight\n                    bi = bi * scale + (1 - scale) * eye\n                inp = rearrange(inp, \"(c g k) ... -> (c k g) ...\", g=2, k=2**i * r_b)\n                inp = rearrange(inp, \"(d b) ... -> d b ...\", b=b)\n                inp = torch.einsum(\"b i j, b j ... -> b i ...\", bi, inp)\n                inp = rearrange(inp, \"d b ... -> (d b) ...\")\n                inp = rearrange(inp, \"(c k g) ... -> (c g k) ...\", g=2, k=2**i * r_b)\n            merged_weight = inp\n\n        # Rescale mechanism\n        if self.rescale is not None:\n            merged_weight = self.rescale.to(merged_weight) * merged_weight\n\n        updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)\n        output_shape = orig_weight.shape\n        return self.finalize_updown(updown, orig_weight, output_shape)\n"
  },
  {
    "path": "extensions-builtin/Lora/networks.py",
    "content": "from __future__ import annotations\r\nimport gradio as gr\r\nimport logging\r\nimport os\r\nimport re\r\n\r\nimport lora_patches\r\nimport network\r\nimport network_lora\r\nimport network_glora\r\nimport network_hada\r\nimport network_ia3\r\nimport network_lokr\r\nimport network_full\r\nimport network_norm\r\nimport network_oft\r\n\r\nimport torch\r\nfrom typing import Union\r\n\r\nfrom modules import shared, devices, sd_models, errors, scripts, sd_hijack\r\nimport modules.textual_inversion.textual_inversion as textual_inversion\r\nimport modules.models.sd3.mmdit\r\n\r\nfrom lora_logger import logger\r\n\r\nmodule_types = [\r\n    network_lora.ModuleTypeLora(),\r\n    network_hada.ModuleTypeHada(),\r\n    network_ia3.ModuleTypeIa3(),\r\n    network_lokr.ModuleTypeLokr(),\r\n    network_full.ModuleTypeFull(),\r\n    network_norm.ModuleTypeNorm(),\r\n    network_glora.ModuleTypeGLora(),\r\n    network_oft.ModuleTypeOFT(),\r\n]\r\n\r\n\r\nre_digits = re.compile(r\"\\d+\")\r\nre_x_proj = re.compile(r\"(.*)_([qkv]_proj)$\")\r\nre_compiled = {}\r\n\r\nsuffix_conversion = {\r\n    \"attentions\": {},\r\n    \"resnets\": {\r\n        \"conv1\": \"in_layers_2\",\r\n        \"conv2\": \"out_layers_3\",\r\n        \"norm1\": \"in_layers_0\",\r\n        \"norm2\": \"out_layers_0\",\r\n        \"time_emb_proj\": \"emb_layers_1\",\r\n        \"conv_shortcut\": \"skip_connection\",\r\n    }\r\n}\r\n\r\n\r\ndef convert_diffusers_name_to_compvis(key, is_sd2):\r\n    def match(match_list, regex_text):\r\n        regex = re_compiled.get(regex_text)\r\n        if regex is None:\r\n            regex = re.compile(regex_text)\r\n            re_compiled[regex_text] = regex\r\n\r\n        r = re.match(regex, key)\r\n        if not r:\r\n            return False\r\n\r\n        match_list.clear()\r\n        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])\r\n        return True\r\n\r\n    m = []\r\n\r\n    if match(m, 
r\"lora_unet_conv_in(.*)\"):\r\n        return f'diffusion_model_input_blocks_0_0{m[0]}'\r\n\r\n    if match(m, r\"lora_unet_conv_out(.*)\"):\r\n        return f'diffusion_model_out_2{m[0]}'\r\n\r\n    if match(m, r\"lora_unet_time_embedding_linear_(\\d+)(.*)\"):\r\n        return f\"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}\"\r\n\r\n    if match(m, r\"lora_unet_down_blocks_(\\d+)_(attentions|resnets)_(\\d+)_(.+)\"):\r\n        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])\r\n        return f\"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}\"\r\n\r\n    if match(m, r\"lora_unet_mid_block_(attentions|resnets)_(\\d+)_(.+)\"):\r\n        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])\r\n        return f\"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}\"\r\n\r\n    if match(m, r\"lora_unet_up_blocks_(\\d+)_(attentions|resnets)_(\\d+)_(.+)\"):\r\n        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])\r\n        return f\"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}\"\r\n\r\n    if match(m, r\"lora_unet_down_blocks_(\\d+)_downsamplers_0_conv\"):\r\n        return f\"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op\"\r\n\r\n    if match(m, r\"lora_unet_up_blocks_(\\d+)_upsamplers_0_conv\"):\r\n        return f\"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv\"\r\n\r\n    if match(m, r\"lora_te_text_model_encoder_layers_(\\d+)_(.+)\"):\r\n        if is_sd2:\r\n            if 'mlp_fc1' in m[1]:\r\n                return f\"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}\"\r\n            elif 'mlp_fc2' in m[1]:\r\n                return f\"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}\"\r\n            else:\r\n                return f\"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}\"\r\n\r\n        return 
f\"transformer_text_model_encoder_layers_{m[0]}_{m[1]}\"\r\n\r\n    if match(m, r\"lora_te2_text_model_encoder_layers_(\\d+)_(.+)\"):\r\n        if 'mlp_fc1' in m[1]:\r\n            return f\"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}\"\r\n        elif 'mlp_fc2' in m[1]:\r\n            return f\"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}\"\r\n        else:\r\n            return f\"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}\"\r\n\r\n    return key\r\n\r\n\r\ndef assign_network_names_to_compvis_modules(sd_model):\r\n    network_layer_mapping = {}\r\n\r\n    if shared.sd_model.is_sdxl:\r\n        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):\r\n            if not hasattr(embedder, 'wrapped'):\r\n                continue\r\n\r\n            for name, module in embedder.wrapped.named_modules():\r\n                network_name = f'{i}_{name.replace(\".\", \"_\")}'\r\n                network_layer_mapping[network_name] = module\r\n                module.network_layer_name = network_name\r\n    else:\r\n        cond_stage_model = getattr(shared.sd_model.cond_stage_model, 'wrapped', shared.sd_model.cond_stage_model)\r\n\r\n        for name, module in cond_stage_model.named_modules():\r\n            network_name = name.replace(\".\", \"_\")\r\n            network_layer_mapping[network_name] = module\r\n            module.network_layer_name = network_name\r\n\r\n    for name, module in shared.sd_model.model.named_modules():\r\n        network_name = name.replace(\".\", \"_\")\r\n        network_layer_mapping[network_name] = module\r\n        module.network_layer_name = network_name\r\n\r\n    sd_model.network_layer_mapping = network_layer_mapping\r\n\r\n\r\nclass BundledTIHash(str):\r\n    def __init__(self, hash_str):\r\n        self.hash = hash_str\r\n\r\n    def __str__(self):\r\n        return self.hash if shared.opts.lora_bundled_ti_to_infotext else 
''\r\n\r\n\r\ndef load_network(name, network_on_disk):\r\n    net = network.Network(name, network_on_disk)\r\n    net.mtime = os.path.getmtime(network_on_disk.filename)\r\n\r\n    sd = sd_models.read_state_dict(network_on_disk.filename)\r\n\r\n    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0\r\n    if not hasattr(shared.sd_model, 'network_layer_mapping'):\r\n        assign_network_names_to_compvis_modules(shared.sd_model)\r\n\r\n    keys_failed_to_match = {}\r\n    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping\r\n    if hasattr(shared.sd_model, 'diffusers_weight_map'):\r\n        diffusers_weight_map = shared.sd_model.diffusers_weight_map\r\n    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):\r\n        diffusers_weight_map = {}\r\n        for k, v in shared.sd_model.diffusers_weight_mapping():\r\n            diffusers_weight_map[k] = v\r\n        shared.sd_model.diffusers_weight_map = diffusers_weight_map\r\n    else:\r\n        diffusers_weight_map = None\r\n\r\n    matched_networks = {}\r\n    bundle_embeddings = {}\r\n\r\n    for key_network, weight in sd.items():\r\n\r\n        if diffusers_weight_map:\r\n            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(\".\", 2)\r\n            network_part = network_name + '.' 
+ network_weight\r\n        else:\r\n            key_network_without_network_parts, _, network_part = key_network.partition(\".\")\r\n\r\n        if key_network_without_network_parts == \"bundle_emb\":\r\n            emb_name, vec_name = network_part.split(\".\", 1)\r\n            emb_dict = bundle_embeddings.get(emb_name, {})\r\n            if vec_name.split('.')[0] == 'string_to_param':\r\n                _, k2 = vec_name.split('.', 1)\r\n                emb_dict['string_to_param'] = {k2: weight}\r\n            else:\r\n                emb_dict[vec_name] = weight\r\n            bundle_embeddings[emb_name] = emb_dict\r\n\r\n        if diffusers_weight_map:\r\n            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)\r\n        else:\r\n            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)\r\n\r\n        sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n\r\n        if sd_module is None:\r\n            m = re_x_proj.match(key)\r\n            if m:\r\n                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)\r\n\r\n        # SDXL loras seem to already have correct compvis keys, so only need to replace \"lora_unet\" with \"diffusion_model\"\r\n        if sd_module is None and \"lora_unet\" in key_network_without_network_parts:\r\n            key = key_network_without_network_parts.replace(\"lora_unet\", \"diffusion_model\")\r\n            sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n        elif sd_module is None and \"lora_te1_text_model\" in key_network_without_network_parts:\r\n            key = key_network_without_network_parts.replace(\"lora_te1_text_model\", \"0_transformer_text_model\")\r\n            sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n\r\n            # some SD1 Loras also have correct compvis keys\r\n            if sd_module is None:\r\n                key = 
key_network_without_network_parts.replace(\"lora_te1_text_model\", \"transformer_text_model\")\r\n                sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n\r\n        # kohya_ss OFT module\r\n        elif sd_module is None and \"oft_unet\" in key_network_without_network_parts:\r\n            key = key_network_without_network_parts.replace(\"oft_unet\", \"diffusion_model\")\r\n            sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n\r\n        # KohakuBlueLeaf OFT module\r\n        if sd_module is None and \"oft_diag\" in key:\r\n            key = key_network_without_network_parts.replace(\"lora_unet\", \"diffusion_model\")\r\n            key = key.replace(\"lora_te1_text_model\", \"0_transformer_text_model\")\r\n            sd_module = shared.sd_model.network_layer_mapping.get(key, None)\r\n\r\n        if sd_module is None:\r\n            keys_failed_to_match[key_network] = key\r\n            continue\r\n\r\n        if key not in matched_networks:\r\n            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)\r\n\r\n        matched_networks[key].w[network_part] = weight\r\n\r\n    for key, weights in matched_networks.items():\r\n        net_module = None\r\n        for nettype in module_types:\r\n            net_module = nettype.create_module(net, weights)\r\n            if net_module is not None:\r\n                break\r\n\r\n        if net_module is None:\r\n            raise AssertionError(f\"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}\")\r\n\r\n        net.modules[key] = net_module\r\n\r\n    embeddings = {}\r\n    for emb_name, data in bundle_embeddings.items():\r\n        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + \"/\" + emb_name)\r\n        
embedding.loaded = None\r\n        embedding.shorthash = BundledTIHash(name)\r\n        embeddings[emb_name] = embedding\r\n\r\n    net.bundle_embeddings = embeddings\r\n\r\n    if keys_failed_to_match:\r\n        logging.debug(f\"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}\")\r\n\r\n    return net\r\n\r\n\r\ndef purge_networks_from_memory():\r\n    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:\r\n        name = next(iter(networks_in_memory))\r\n        networks_in_memory.pop(name, None)\r\n\r\n    devices.torch_gc()\r\n\r\n\r\ndef load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):\r\n    emb_db = sd_hijack.model_hijack.embedding_db\r\n    already_loaded = {}\r\n\r\n    for net in loaded_networks:\r\n        if net.name in names:\r\n            already_loaded[net.name] = net\r\n        for emb_name, embedding in net.bundle_embeddings.items():\r\n            if embedding.loaded:\r\n                emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)\r\n\r\n    loaded_networks.clear()\r\n\r\n    unavailable_networks = []\r\n    for name in names:\r\n        if name.lower() in forbidden_network_aliases and available_networks.get(name) is None:\r\n            unavailable_networks.append(name)\r\n        elif available_network_aliases.get(name) is None:\r\n            unavailable_networks.append(name)\r\n\r\n    if unavailable_networks:\r\n        update_available_networks_by_names(unavailable_networks)\r\n\r\n    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]\r\n    if any(x is None for x in networks_on_disk):\r\n        list_available_networks()\r\n\r\n        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in 
names]\r\n\r\n    failed_to_load_networks = []\r\n\r\n    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):\r\n        net = already_loaded.get(name, None)\r\n\r\n        if network_on_disk is not None:\r\n            if net is None:\r\n                net = networks_in_memory.get(name)\r\n\r\n            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:\r\n                try:\r\n                    net = load_network(name, network_on_disk)\r\n\r\n                    networks_in_memory.pop(name, None)\r\n                    networks_in_memory[name] = net\r\n                except Exception as e:\r\n                    errors.display(e, f\"loading network {network_on_disk.filename}\")\r\n                    continue\r\n\r\n            net.mentioned_name = name\r\n\r\n            network_on_disk.read_hash()\r\n\r\n        if net is None:\r\n            failed_to_load_networks.append(name)\r\n            logging.info(f\"Couldn't find network with name {name}\")\r\n            continue\r\n\r\n        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0\r\n        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0\r\n        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0\r\n        loaded_networks.append(net)\r\n\r\n        for emb_name, embedding in net.bundle_embeddings.items():\r\n            if embedding.loaded is None and emb_name in emb_db.word_embeddings:\r\n                logger.warning(\r\n                    f'Skip bundle embedding: \"{emb_name}\"'\r\n                    ' as it was already loaded from embeddings folder'\r\n                )\r\n                continue\r\n\r\n            embedding.loaded = False\r\n            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:\r\n                embedding.loaded = True\r\n                emb_db.register_embedding(embedding, shared.sd_model)\r\n            else:\r\n                
emb_db.skipped_embeddings[emb_name] = embedding\r\n\r\n    if failed_to_load_networks:\r\n        lora_not_found_message = f'Lora not found: {\", \".join(failed_to_load_networks)}'\r\n        sd_hijack.model_hijack.comments.append(lora_not_found_message)\r\n        if shared.opts.lora_not_found_warning_console:\r\n            print(f'\\n{lora_not_found_message}\\n')\r\n        if shared.opts.lora_not_found_gradio_warning:\r\n            gr.Warning(lora_not_found_message)\r\n\r\n    purge_networks_from_memory()\r\n\r\n\r\ndef allowed_layer_without_weight(layer):\r\n    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:\r\n        return True\r\n\r\n    return False\r\n\r\n\r\ndef store_weights_backup(weight):\r\n    if weight is None:\r\n        return None\r\n\r\n    return weight.to(devices.cpu, copy=True)\r\n\r\n\r\ndef restore_weights_backup(obj, field, weight):\r\n    if weight is None:\r\n        setattr(obj, field, None)\r\n        return\r\n\r\n    getattr(obj, field).copy_(weight)\r\n\r\n\r\ndef network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):\r\n    weights_backup = getattr(self, \"network_weights_backup\", None)\r\n    bias_backup = getattr(self, \"network_bias_backup\", None)\r\n\r\n    if weights_backup is None and bias_backup is None:\r\n        return\r\n\r\n    if weights_backup is not None:\r\n        if isinstance(self, torch.nn.MultiheadAttention):\r\n            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])\r\n            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])\r\n        else:\r\n            restore_weights_backup(self, 'weight', weights_backup)\r\n\r\n    if isinstance(self, torch.nn.MultiheadAttention):\r\n        restore_weights_backup(self.out_proj, 'bias', bias_backup)\r\n    else:\r\n        restore_weights_backup(self, 'bias', bias_backup)\r\n\r\n\r\ndef 
network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):\r\n    \"\"\"\r\n    Applies the currently selected set of networks to the weights of torch layer self.\r\n    If weights already have this particular set of networks applied, does nothing.\r\n    If not, restores original weights from backup and alters weights according to networks.\r\n    \"\"\"\r\n\r\n    network_layer_name = getattr(self, 'network_layer_name', None)\r\n    if network_layer_name is None:\r\n        return\r\n\r\n    current_names = getattr(self, \"network_current_names\", ())\r\n    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)\r\n\r\n    weights_backup = getattr(self, \"network_weights_backup\", None)\r\n    if weights_backup is None and wanted_names != ():\r\n        if current_names != () and not allowed_layer_without_weight(self):\r\n            raise RuntimeError(f\"{network_layer_name} - no backup weights found and current weights are not unchanged\")\r\n\r\n        if isinstance(self, torch.nn.MultiheadAttention):\r\n            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))\r\n        else:\r\n            weights_backup = store_weights_backup(self.weight)\r\n\r\n        self.network_weights_backup = weights_backup\r\n\r\n    bias_backup = getattr(self, \"network_bias_backup\", None)\r\n    if bias_backup is None and wanted_names != ():\r\n        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:\r\n            bias_backup = store_weights_backup(self.out_proj.bias)\r\n        elif getattr(self, 'bias', None) is not None:\r\n            bias_backup = store_weights_backup(self.bias)\r\n        else:\r\n            bias_backup = None\r\n\r\n        # Unlike weight which always has value, some modules don't have bias.\r\n        # Only report if bias is not 
None and current bias are not unchanged.\r\n        if bias_backup is not None and current_names != ():\r\n            raise RuntimeError(\"no backup bias found and current bias are not unchanged\")\r\n\r\n        self.network_bias_backup = bias_backup\r\n\r\n    if current_names != wanted_names:\r\n        network_restore_weights_from_backup(self)\r\n\r\n        for net in loaded_networks:\r\n            module = net.modules.get(network_layer_name, None)\r\n            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):\r\n                try:\r\n                    with torch.no_grad():\r\n                        if getattr(self, 'fp16_weight', None) is None:\r\n                            weight = self.weight\r\n                            bias = self.bias\r\n                        else:\r\n                            weight = self.fp16_weight.clone().to(self.weight.device)\r\n                            bias = getattr(self, 'fp16_bias', None)\r\n                            if bias is not None:\r\n                                bias = bias.clone().to(self.bias.device)\r\n                        updown, ex_bias = module.calc_updown(weight)\r\n\r\n                        if len(weight.shape) == 4 and weight.shape[1] == 9:\r\n                            # inpainting model. 
zero pad updown to make channel[1] go from 4 to 9\r\n                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))\r\n\r\n                        self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))\r\n                        if ex_bias is not None and hasattr(self, 'bias'):\r\n                            if self.bias is None:\r\n                                self.bias = torch.nn.Parameter(ex_bias.to(self.weight.dtype))\r\n                            else:\r\n                                self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))\r\n                except RuntimeError as e:\r\n                    logging.debug(f\"Network {net.name} layer {network_layer_name}: {e}\")\r\n                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1\r\n\r\n                continue\r\n\r\n            module_q = net.modules.get(network_layer_name + \"_q_proj\", None)\r\n            module_k = net.modules.get(network_layer_name + \"_k_proj\", None)\r\n            module_v = net.modules.get(network_layer_name + \"_v_proj\", None)\r\n            module_out = net.modules.get(network_layer_name + \"_out_proj\", None)\r\n\r\n            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:\r\n                try:\r\n                    with torch.no_grad():\r\n                        # Send \"real\" orig_weight into MHA's lora module\r\n                        qw, kw, vw = self.in_proj_weight.chunk(3, 0)\r\n                        updown_q, _ = module_q.calc_updown(qw)\r\n                        updown_k, _ = module_k.calc_updown(kw)\r\n                        updown_v, _ = module_v.calc_updown(vw)\r\n                        del qw, kw, vw\r\n                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])\r\n                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)\r\n\r\n    
                    self.in_proj_weight += updown_qkv\r\n                        self.out_proj.weight += updown_out\r\n                    if ex_bias is not None:\r\n                        if self.out_proj.bias is None:\r\n                            self.out_proj.bias = torch.nn.Parameter(ex_bias)\r\n                        else:\r\n                            self.out_proj.bias += ex_bias\r\n\r\n                except RuntimeError as e:\r\n                    logging.debug(f\"Network {net.name} layer {network_layer_name}: {e}\")\r\n                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1\r\n\r\n                continue\r\n\r\n            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:\r\n                try:\r\n                    with torch.no_grad():\r\n                        # Send \"real\" orig_weight into MHA's lora module\r\n                        qw, kw, vw = self.weight.chunk(3, 0)\r\n                        updown_q, _ = module_q.calc_updown(qw)\r\n                        updown_k, _ = module_k.calc_updown(kw)\r\n                        updown_v, _ = module_v.calc_updown(vw)\r\n                        del qw, kw, vw\r\n                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])\r\n                        self.weight += updown_qkv\r\n\r\n                except RuntimeError as e:\r\n                    logging.debug(f\"Network {net.name} layer {network_layer_name}: {e}\")\r\n                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1\r\n\r\n                continue\r\n\r\n            if module is None:\r\n                continue\r\n\r\n            logging.debug(f\"Network {net.name} layer {network_layer_name}: couldn't find supported operation\")\r\n            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1\r\n\r\n        self.network_current_names = 
wanted_names\r\n\r\n\r\ndef network_forward(org_module, input, original_forward):\r\n    \"\"\"\r\n    Old way of applying Lora by executing operations during layer's forward.\r\n    Stacking many loras this way results in big performance degradation.\r\n    \"\"\"\r\n\r\n    if len(loaded_networks) == 0:\r\n        return original_forward(org_module, input)\r\n\r\n    input = devices.cond_cast_unet(input)\r\n\r\n    network_restore_weights_from_backup(org_module)\r\n    network_reset_cached_weight(org_module)\r\n\r\n    y = original_forward(org_module, input)\r\n\r\n    network_layer_name = getattr(org_module, 'network_layer_name', None)\r\n    for lora in loaded_networks:\r\n        module = lora.modules.get(network_layer_name, None)\r\n        if module is None:\r\n            continue\r\n\r\n        y = module.forward(input, y)\r\n\r\n    return y\r\n\r\n\r\ndef network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):\r\n    self.network_current_names = ()\r\n    self.network_weights_backup = None\r\n    self.network_bias_backup = None\r\n\r\n\r\ndef network_Linear_forward(self, input):\r\n    if shared.opts.lora_functional:\r\n        return network_forward(self, input, originals.Linear_forward)\r\n\r\n    network_apply_weights(self)\r\n\r\n    return originals.Linear_forward(self, input)\r\n\r\n\r\ndef network_Linear_load_state_dict(self, *args, **kwargs):\r\n    network_reset_cached_weight(self)\r\n\r\n    return originals.Linear_load_state_dict(self, *args, **kwargs)\r\n\r\n\r\ndef network_Conv2d_forward(self, input):\r\n    if shared.opts.lora_functional:\r\n        return network_forward(self, input, originals.Conv2d_forward)\r\n\r\n    network_apply_weights(self)\r\n\r\n    return originals.Conv2d_forward(self, input)\r\n\r\n\r\ndef network_Conv2d_load_state_dict(self, *args, **kwargs):\r\n    network_reset_cached_weight(self)\r\n\r\n    return originals.Conv2d_load_state_dict(self, *args, **kwargs)\r\n\r\n\r\ndef 
network_GroupNorm_forward(self, input):\r\n    if shared.opts.lora_functional:\r\n        return network_forward(self, input, originals.GroupNorm_forward)\r\n\r\n    network_apply_weights(self)\r\n\r\n    return originals.GroupNorm_forward(self, input)\r\n\r\n\r\ndef network_GroupNorm_load_state_dict(self, *args, **kwargs):\r\n    network_reset_cached_weight(self)\r\n\r\n    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)\r\n\r\n\r\ndef network_LayerNorm_forward(self, input):\r\n    if shared.opts.lora_functional:\r\n        return network_forward(self, input, originals.LayerNorm_forward)\r\n\r\n    network_apply_weights(self)\r\n\r\n    return originals.LayerNorm_forward(self, input)\r\n\r\n\r\ndef network_LayerNorm_load_state_dict(self, *args, **kwargs):\r\n    network_reset_cached_weight(self)\r\n\r\n    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)\r\n\r\n\r\ndef network_MultiheadAttention_forward(self, *args, **kwargs):\r\n    network_apply_weights(self)\r\n\r\n    return originals.MultiheadAttention_forward(self, *args, **kwargs)\r\n\r\n\r\ndef network_MultiheadAttention_load_state_dict(self, *args, **kwargs):\r\n    network_reset_cached_weight(self)\r\n\r\n    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)\r\n\r\n\r\ndef process_network_files(names: list[str] | None = None):\r\n    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[\".pt\", \".ckpt\", \".safetensors\"]))\r\n    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[\".pt\", \".ckpt\", \".safetensors\"]))\r\n    for filename in candidates:\r\n        if os.path.isdir(filename):\r\n            continue\r\n        name = os.path.splitext(os.path.basename(filename))[0]\r\n        # if names is provided, only load networks with names in the list\r\n        if names and name not in names:\r\n            continue\r\n        try:\r\n            entry = 
network.NetworkOnDisk(name, filename)\r\n        except OSError:  # should catch FileNotFoundError and PermissionError etc.\r\n            errors.report(f\"Failed to load network {name} from {filename}\", exc_info=True)\r\n            continue\r\n\r\n        available_networks[name] = entry\r\n\r\n        if entry.alias in available_network_aliases:\r\n            forbidden_network_aliases[entry.alias.lower()] = 1\r\n\r\n        available_network_aliases[name] = entry\r\n        available_network_aliases[entry.alias] = entry\r\n\r\n\r\ndef update_available_networks_by_names(names: list[str]):\r\n    process_network_files(names)\r\n\r\n\r\ndef list_available_networks():\r\n    available_networks.clear()\r\n    available_network_aliases.clear()\r\n    forbidden_network_aliases.clear()\r\n    available_network_hash_lookup.clear()\r\n    forbidden_network_aliases.update({\"none\": 1, \"addams\": 1})\r\n\r\n    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)\r\n\r\n    process_network_files()\r\n\r\n\r\nre_network_name = re.compile(r\"(.*)\\s*\\([0-9a-fA-F]+\\)\")\r\n\r\n\r\ndef infotext_pasted(infotext, params):\r\n    if \"AddNet Module 1\" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:\r\n        return  # if the other extension is active, it will handle those fields, no need to do anything\r\n\r\n    added = []\r\n\r\n    for k in params:\r\n        if not k.startswith(\"AddNet Model \"):\r\n            continue\r\n\r\n        num = k[13:]\r\n\r\n        if params.get(\"AddNet Module \" + num) != \"LoRA\":\r\n            continue\r\n\r\n        name = params.get(\"AddNet Model \" + num)\r\n        if name is None:\r\n            continue\r\n\r\n        m = re_network_name.match(name)\r\n        if m:\r\n            name = m.group(1)\r\n\r\n        multiplier = params.get(\"AddNet Weight A \" + num, \"1.0\")\r\n\r\n        added.append(f\"<lora:{name}:{multiplier}>\")\r\n\r\n    if added:\r\n        params[\"Prompt\"] += \"\\n\" + 
\"\".join(added)\r\n\r\n\r\noriginals: lora_patches.LoraPatches = None\r\n\r\nextra_network_lora = None\r\n\r\navailable_networks = {}\r\navailable_network_aliases = {}\r\nloaded_networks = []\r\nloaded_bundle_embeddings = {}\r\nnetworks_in_memory = {}\r\navailable_network_hash_lookup = {}\r\nforbidden_network_aliases = {}\r\n\r\nlist_available_networks()\r\n"
  },
  {
    "path": "extensions-builtin/Lora/preload.py",
    "content": "import os\r\nfrom modules import paths\r\nfrom modules.paths_internal import normalized_filepath\r\n\r\n\r\ndef preload(parser):\r\n    parser.add_argument(\"--lora-dir\", type=normalized_filepath, help=\"Path to directory with Lora networks.\", default=os.path.join(paths.models_path, 'Lora'))\r\n    parser.add_argument(\"--lyco-dir-backcompat\", type=normalized_filepath, help=\"Path to directory with LyCORIS networks (for backwards compatibility; can also use --lyco-dir).\", default=os.path.join(paths.models_path, 'LyCORIS'))\r\n"
  },
  {
    "path": "extensions-builtin/Lora/scripts/lora_script.py",
    "content": "import re\r\n\r\nimport gradio as gr\r\nfrom fastapi import FastAPI\r\n\r\nimport network\r\nimport networks\r\nimport lora  # noqa:F401\r\nimport lora_patches\r\nimport extra_networks_lora\r\nimport ui_extra_networks_lora\r\nfrom modules import script_callbacks, ui_extra_networks, extra_networks, shared\r\n\r\n\r\ndef unload():\r\n    networks.originals.undo()\r\n\r\n\r\ndef before_ui():\r\n    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())\r\n\r\n    networks.extra_network_lora = extra_networks_lora.ExtraNetworkLora()\r\n    extra_networks.register_extra_network(networks.extra_network_lora)\r\n    extra_networks.register_extra_network_alias(networks.extra_network_lora, \"lyco\")\r\n\r\n\r\nnetworks.originals = lora_patches.LoraPatches()\r\n\r\nscript_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)\r\nscript_callbacks.on_script_unloaded(unload)\r\nscript_callbacks.on_before_ui(before_ui)\r\nscript_callbacks.on_infotext_pasted(networks.infotext_pasted)\r\n\r\n\r\nshared.options_templates.update(shared.options_section(('extra_networks', \"Extra Networks\"), {\r\n    \"sd_lora\": shared.OptionInfo(\"None\", \"Add network to prompt\", gr.Dropdown, lambda: {\"choices\": [\"None\", *networks.available_networks]}, refresh=networks.list_available_networks),\r\n    \"lora_preferred_name\": shared.OptionInfo(\"Alias from file\", \"When adding to prompt, refer to Lora by\", gr.Radio, {\"choices\": [\"Alias from file\", \"Filename\"]}),\r\n    \"lora_add_hashes_to_infotext\": shared.OptionInfo(True, \"Add Lora hashes to infotext\"),\r\n    \"lora_bundled_ti_to_infotext\": shared.OptionInfo(True, \"Add Lora name as TI hashes for bundled Textual Inversion\").info('\"Add Textual Inversion hashes to infotext\" needs to be enabled'),\r\n    \"lora_show_all\": shared.OptionInfo(False, \"Always show all networks on the Lora page\").info(\"otherwise, those detected as for incompatible version of Stable 
Diffusion will be hidden\"),\r\n    \"lora_hide_unknown_for_versions\": shared.OptionInfo([], \"Hide networks of unknown versions for model versions\", gr.CheckboxGroup, {\"choices\": [\"SD1\", \"SD2\", \"SDXL\"]}),\r\n    \"lora_in_memory_limit\": shared.OptionInfo(0, \"Number of Lora networks to keep cached in memory\", gr.Number, {\"precision\": 0}),\r\n    \"lora_not_found_warning_console\": shared.OptionInfo(False, \"Lora not found warning in console\"),\r\n    \"lora_not_found_gradio_warning\": shared.OptionInfo(False, \"Lora not found warning popup in webui\"),\r\n}))\r\n\r\n\r\nshared.options_templates.update(shared.options_section(('compatibility', \"Compatibility\"), {\r\n    \"lora_functional\": shared.OptionInfo(False, \"Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension\"),\r\n}))\r\n\r\n\r\ndef create_lora_json(obj: network.NetworkOnDisk):\r\n    return {\r\n        \"name\": obj.name,\r\n        \"alias\": obj.alias,\r\n        \"path\": obj.filename,\r\n        \"metadata\": obj.metadata,\r\n    }\r\n\r\n\r\ndef api_networks(_: gr.Blocks, app: FastAPI):\r\n    @app.get(\"/sdapi/v1/loras\")\r\n    async def get_loras():\r\n        return [create_lora_json(obj) for obj in networks.available_networks.values()]\r\n\r\n    @app.post(\"/sdapi/v1/refresh-loras\")\r\n    async def refresh_loras():\r\n        return networks.list_available_networks()\r\n\r\n\r\nscript_callbacks.on_app_started(api_networks)\r\n\r\nre_lora = re.compile(\"<lora:([^:]+):\")\r\n\r\n\r\ndef infotext_pasted(infotext, d):\r\n    hashes = d.get(\"Lora hashes\")\r\n    if not hashes:\r\n        return\r\n\r\n    hashes = [x.strip().split(':', 1) for x in hashes.split(\",\")]\r\n    hashes = {x[0].strip().replace(\",\", \"\"): x[1].strip() for x in hashes}\r\n\r\n    def network_replacement(m):\r\n        alias = m.group(1)\r\n        shorthash = hashes.get(alias)\r\n        if 
shorthash is None:\r\n            return m.group(0)\r\n\r\n        network_on_disk = networks.available_network_hash_lookup.get(shorthash)\r\n        if network_on_disk is None:\r\n            return m.group(0)\r\n\r\n        return f'<lora:{network_on_disk.get_alias()}:'\r\n\r\n    d[\"Prompt\"] = re.sub(re_lora, network_replacement, d[\"Prompt\"])\r\n\r\n\r\nscript_callbacks.on_infotext_pasted(infotext_pasted)\r\n\r\nshared.opts.onchange(\"lora_in_memory_limit\", networks.purge_networks_from_memory)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/ui_edit_user_metadata.py",
    "content": "import datetime\r\nimport html\r\nimport random\r\n\r\nimport gradio as gr\r\nimport re\r\n\r\nfrom modules import ui_extra_networks_user_metadata\r\n\r\n\r\ndef is_non_comma_tagset(tags):\r\n    average_tag_length = sum(len(x) for x in tags.keys()) / len(tags)\r\n\r\n    return average_tag_length >= 16\r\n\r\n\r\nre_word = re.compile(r\"[-_\\w']+\")\r\nre_comma = re.compile(r\" *, *\")\r\n\r\n\r\ndef build_tags(metadata):\r\n    tags = {}\r\n\r\n    ss_tag_frequency = metadata.get(\"ss_tag_frequency\", {})\r\n    if ss_tag_frequency is not None and hasattr(ss_tag_frequency, 'items'):\r\n        for _, tags_dict in ss_tag_frequency.items():\r\n            for tag, tag_count in tags_dict.items():\r\n                tag = tag.strip()\r\n                tags[tag] = tags.get(tag, 0) + int(tag_count)\r\n\r\n    if tags and is_non_comma_tagset(tags):\r\n        new_tags = {}\r\n\r\n        for text, text_count in tags.items():\r\n            for word in re.findall(re_word, text):\r\n                if len(word) < 3:\r\n                    continue\r\n\r\n                new_tags[word] = new_tags.get(word, 0) + text_count\r\n\r\n        tags = new_tags\r\n\r\n    ordered_tags = sorted(tags.keys(), key=tags.get, reverse=True)\r\n\r\n    return [(tag, tags[tag]) for tag in ordered_tags]\r\n\r\n\r\nclass LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):\r\n    def __init__(self, ui, tabname, page):\r\n        super().__init__(ui, tabname, page)\r\n\r\n        self.select_sd_version = None\r\n\r\n        self.taginfo = None\r\n        self.edit_activation_text = None\r\n        self.slider_preferred_weight = None\r\n        self.edit_notes = None\r\n\r\n    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):\r\n        user_metadata = self.get_user_metadata(name)\r\n        user_metadata[\"description\"] = desc\r\n        user_metadata[\"sd version\"] = sd_version\r\n     
   user_metadata[\"activation text\"] = activation_text\r\n        user_metadata[\"preferred weight\"] = preferred_weight\r\n        user_metadata[\"negative text\"] = negative_text\r\n        user_metadata[\"notes\"] = notes\r\n\r\n        self.write_user_metadata(name, user_metadata)\r\n\r\n    def get_metadata_table(self, name):\r\n        table = super().get_metadata_table(name)\r\n        item = self.page.items.get(name, {})\r\n        metadata = item.get(\"metadata\") or {}\r\n\r\n        keys = {\r\n            'ss_output_name': \"Output name:\",\r\n            'ss_sd_model_name': \"Model:\",\r\n            'ss_clip_skip': \"Clip skip:\",\r\n            'ss_network_module': \"Kohya module:\",\r\n        }\r\n\r\n        for key, label in keys.items():\r\n            value = metadata.get(key, None)\r\n            if value is not None and str(value) != \"None\":\r\n                table.append((label, html.escape(value)))\r\n\r\n        ss_training_started_at = metadata.get('ss_training_started_at')\r\n        if ss_training_started_at:\r\n            table.append((\"Date trained:\", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))\r\n\r\n        ss_bucket_info = metadata.get(\"ss_bucket_info\")\r\n        if ss_bucket_info and \"buckets\" in ss_bucket_info:\r\n            resolutions = {}\r\n            for _, bucket in ss_bucket_info[\"buckets\"].items():\r\n                resolution = bucket[\"resolution\"]\r\n                resolution = f'{resolution[1]}x{resolution[0]}'\r\n\r\n                resolutions[resolution] = resolutions.get(resolution, 0) + int(bucket[\"count\"])\r\n\r\n            resolutions_list = sorted(resolutions.keys(), key=resolutions.get, reverse=True)\r\n            resolutions_text = html.escape(\", \".join(resolutions_list[0:4]))\r\n            if len(resolutions) > 4:\r\n                resolutions_text += \", ...\"\r\n                resolutions_text = f\"<span title='{html.escape(', 
'.join(resolutions_list))}'>{resolutions_text}</span>\"\r\n\r\n            table.append(('Resolutions:' if len(resolutions_list) > 1 else 'Resolution:', resolutions_text))\r\n\r\n        image_count = 0\r\n        for _, params in metadata.get(\"ss_dataset_dirs\", {}).items():\r\n            image_count += int(params.get(\"img_count\", 0))\r\n\r\n        if image_count:\r\n            table.append((\"Dataset size:\", image_count))\r\n\r\n        return table\r\n\r\n    def put_values_into_components(self, name):\r\n        user_metadata = self.get_user_metadata(name)\r\n        values = super().put_values_into_components(name)\r\n\r\n        item = self.page.items.get(name, {})\r\n        metadata = item.get(\"metadata\") or {}\r\n\r\n        tags = build_tags(metadata)\r\n        gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]\r\n\r\n        return [\r\n            *values[0:5],\r\n            item.get(\"sd_version\", \"Unknown\"),\r\n            gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),\r\n            user_metadata.get('activation text', ''),\r\n            float(user_metadata.get('preferred weight', 0.0)),\r\n            user_metadata.get('negative text', ''),\r\n            gr.update(visible=True if tags else False),\r\n            gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),\r\n        ]\r\n\r\n    def generate_random_prompt(self, name):\r\n        item = self.page.items.get(name, {})\r\n        metadata = item.get(\"metadata\") or {}\r\n        tags = build_tags(metadata)\r\n\r\n        return self.generate_random_prompt_from_tags(tags)\r\n\r\n    def generate_random_prompt_from_tags(self, tags):\r\n        max_count = None\r\n        res = []\r\n        for tag, count in tags:\r\n            if not max_count:\r\n                max_count = count\r\n\r\n            v = random.random() * max_count\r\n            if count > v:\r\n                for x in 
\"({[]})\":\r\n                    tag = tag.replace(x, '\\\\' + x)\r\n                res.append(tag)\r\n\r\n        return \", \".join(sorted(res))\r\n\r\n    def create_extra_default_items_in_left_column(self):\r\n\r\n        # this would be a lot better as gr.Radio but I can't make it work\r\n        self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)\r\n\r\n    def create_editor(self):\r\n        self.create_default_editor_elems()\r\n\r\n        self.taginfo = gr.HighlightedText(label=\"Training dataset tags\")\r\n        self.edit_activation_text = gr.Text(label='Activation text', info=\"Will be added to prompt along with Lora\")\r\n        self.slider_preferred_weight = gr.Slider(label='Preferred weight', info=\"Set to 0 to disable\", minimum=0.0, maximum=2.0, step=0.01)\r\n        self.edit_negative_text = gr.Text(label='Negative prompt', info=\"Will be added to negative prompts\")\r\n        with gr.Row() as row_random_prompt:\r\n            with gr.Column(scale=8):\r\n                random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)\r\n\r\n            with gr.Column(scale=1, min_width=120):\r\n                generate_random_prompt = gr.Button('Generate', size=\"lg\", scale=1)\r\n\r\n        self.edit_notes = gr.TextArea(label='Notes', lines=4)\r\n\r\n        generate_random_prompt.click(fn=self.generate_random_prompt, inputs=[self.edit_name_input], outputs=[random_prompt], show_progress=False)\r\n\r\n        def select_tag(activation_text, evt: gr.SelectData):\r\n            tag = evt.value[0]\r\n\r\n            words = re.split(re_comma, activation_text)\r\n            if tag in words:\r\n                words = [x for x in words if x != tag and x.strip()]\r\n                return \", \".join(words)\r\n\r\n            return activation_text + \", \" + tag if activation_text else tag\r\n\r\n        
self.taginfo.select(fn=select_tag, inputs=[self.edit_activation_text], outputs=[self.edit_activation_text], show_progress=False)\r\n\r\n        self.create_default_buttons()\r\n\r\n        viewed_components = [\r\n            self.edit_name,\r\n            self.edit_description,\r\n            self.html_filedata,\r\n            self.html_preview,\r\n            self.edit_notes,\r\n            self.select_sd_version,\r\n            self.taginfo,\r\n            self.edit_activation_text,\r\n            self.slider_preferred_weight,\r\n            self.edit_negative_text,\r\n            row_random_prompt,\r\n            random_prompt,\r\n        ]\r\n\r\n        self.button_edit\\\r\n            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\\\r\n            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])\r\n\r\n        edited_components = [\r\n            self.edit_description,\r\n            self.select_sd_version,\r\n            self.edit_activation_text,\r\n            self.slider_preferred_weight,\r\n            self.edit_negative_text,\r\n            self.edit_notes,\r\n        ]\r\n\r\n\r\n        self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)\r\n"
  },
  {
    "path": "extensions-builtin/Lora/ui_extra_networks_lora.py",
    "content": "import os\r\n\r\nimport network\r\nimport networks\r\n\r\nfrom modules import shared, ui_extra_networks\r\nfrom modules.ui_extra_networks import quote_js\r\nfrom ui_edit_user_metadata import LoraUserMetadataEditor\r\n\r\n\r\nclass ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):\r\n    def __init__(self):\r\n        super().__init__('Lora')\r\n\r\n    def refresh(self):\r\n        networks.list_available_networks()\r\n\r\n    def create_item(self, name, index=None, enable_filter=True):\r\n        lora_on_disk = networks.available_networks.get(name)\r\n        if lora_on_disk is None:\r\n            return\r\n\r\n        path, ext = os.path.splitext(lora_on_disk.filename)\r\n\r\n        alias = lora_on_disk.get_alias()\r\n\r\n        search_terms = [self.search_terms_from_path(lora_on_disk.filename)]\r\n        if lora_on_disk.hash:\r\n            search_terms.append(lora_on_disk.hash)\r\n        item = {\r\n            \"name\": name,\r\n            \"filename\": lora_on_disk.filename,\r\n            \"shorthash\": lora_on_disk.shorthash,\r\n            \"preview\": self.find_preview(path) or self.find_embedded_preview(path, name, lora_on_disk.metadata),\r\n            \"description\": self.find_description(path),\r\n            \"search_terms\": search_terms,\r\n            \"local_preview\": f\"{path}.{shared.opts.samples_format}\",\r\n            \"metadata\": lora_on_disk.metadata,\r\n            \"sort_keys\": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},\r\n            \"sd_version\": lora_on_disk.sd_version.name,\r\n        }\r\n\r\n        self.read_user_metadata(item)\r\n        activation_text = item[\"user_metadata\"].get(\"activation text\")\r\n        preferred_weight = item[\"user_metadata\"].get(\"preferred weight\", 0.0)\r\n        item[\"prompt\"] = quote_js(f\"<lora:{alias}:\") + \" + \" + (str(preferred_weight) if preferred_weight else \"opts.extra_networks_default_multiplier\") + \" + \" + 
quote_js(\">\")\r\n\r\n        if activation_text:\r\n            item[\"prompt\"] += \" + \" + quote_js(\" \" + activation_text)\r\n\r\n        negative_prompt = item[\"user_metadata\"].get(\"negative text\")\r\n        item[\"negative_prompt\"] = quote_js(\"\")\r\n        if negative_prompt:\r\n            item[\"negative_prompt\"] = quote_js('(' + negative_prompt + ':1)')\r\n\r\n        sd_version = item[\"user_metadata\"].get(\"sd version\")\r\n        if sd_version in network.SdVersion.__members__:\r\n            item[\"sd_version\"] = sd_version\r\n            sd_version = network.SdVersion[sd_version]\r\n        else:\r\n            sd_version = lora_on_disk.sd_version\r\n\r\n        if shared.opts.lora_show_all or not enable_filter or not shared.sd_model:\r\n            pass\r\n        elif sd_version == network.SdVersion.Unknown:\r\n            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1\r\n            if model_version.name in shared.opts.lora_hide_unknown_for_versions:\r\n                return None\r\n        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:\r\n            return None\r\n        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:\r\n            return None\r\n        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:\r\n            return None\r\n\r\n        return item\r\n\r\n    def list_items(self):\r\n        # instantiate a list to protect against concurrent modification\r\n        names = list(networks.available_networks)\r\n        for index, name in enumerate(names):\r\n            item = self.create_item(name, index)\r\n            if item is not None:\r\n                yield item\r\n\r\n    def allowed_directories_for_previews(self):\r\n        return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat]\r\n\r\n    def create_user_metadata_editor(self, ui, 
tabname):\r\n        return LoraUserMetadataEditor(ui, tabname, self)\r\n"
  },
  {
    "path": "extensions-builtin/ScuNET/preload.py",
    "content": "import os\r\nfrom modules import paths\r\n\r\n\r\ndef preload(parser):\r\n    parser.add_argument(\"--scunet-models-path\", type=str, help=\"Path to directory with ScuNET model file(s).\", default=os.path.join(paths.models_path, 'ScuNET'))\r\n"
  },
  {
    "path": "extensions-builtin/ScuNET/scripts/scunet_model.py",
    "content": "import sys\n\nimport PIL.Image\n\nimport modules.upscaler\nfrom modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils\n\n\nclass UpscalerScuNET(modules.upscaler.Upscaler):\n    def __init__(self, dirname):\n        self.name = \"ScuNET\"\n        self.model_name = \"ScuNET GAN\"\n        self.model_name2 = \"ScuNET PSNR\"\n        self.model_url = \"https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth\"\n        self.model_url2 = \"https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth\"\n        self.user_path = dirname\n        super().__init__()\n        model_paths = self.find_models(ext_filter=[\".pth\"])\n        scalers = []\n        add_model2 = True\n        for file in model_paths:\n            if file.startswith(\"http\"):\n                name = self.model_name\n            else:\n                name = modelloader.friendly_name(file)\n            if name == self.model_name2 or file == self.model_url2:\n                add_model2 = False\n            try:\n                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)\n                scalers.append(scaler_data)\n            except Exception:\n                errors.report(f\"Error loading ScuNET model: {file}\", exc_info=True)\n        if add_model2:\n            scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)\n            scalers.append(scaler_data2)\n        self.scalers = scalers\n\n    def do_upscale(self, img: PIL.Image.Image, selected_file):\n        devices.torch_gc()\n        try:\n            model = self.load_model(selected_file)\n        except Exception as e:\n            print(f\"ScuNET: Unable to load model from {selected_file}: {e}\", file=sys.stderr)\n            return img\n\n        img = upscaler_utils.upscale_2(\n            img,\n            model,\n            tile_size=shared.opts.SCUNET_tile,\n            
tile_overlap=shared.opts.SCUNET_tile_overlap,\n            scale=1,  # ScuNET is a denoising model, not an upscaler\n            desc='ScuNET',\n        )\n        devices.torch_gc()\n        return img\n\n    def load_model(self, path: str):\n        device = devices.get_device_for('scunet')\n        if path.startswith(\"http\"):\n            # TODO: this doesn't use `path` at all?\n            filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f\"{self.name}.pth\")\n        else:\n            filename = path\n        return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')\n\n\ndef on_ui_settings():\n    import gradio as gr\n\n    shared.opts.add_option(\"SCUNET_tile\", shared.OptionInfo(256, \"Tile size for SCUNET upscalers.\", gr.Slider, {\"minimum\": 0, \"maximum\": 512, \"step\": 16}, section=('upscaling', \"Upscaling\")).info(\"0 = no tiling\"))\n    shared.opts.add_option(\"SCUNET_tile_overlap\", shared.OptionInfo(8, \"Tile overlap for SCUNET upscalers.\", gr.Slider, {\"minimum\": 0, \"maximum\": 64, \"step\": 1}, section=('upscaling', \"Upscaling\")).info(\"Low values = visible seam\"))\n\n\nscript_callbacks.on_ui_settings(on_ui_settings)\n"
  },
  {
    "path": "extensions-builtin/SwinIR/preload.py",
    "content": "import os\r\nfrom modules import paths\r\n\r\n\r\ndef preload(parser):\r\n    parser.add_argument(\"--swinir-models-path\", type=str, help=\"Path to directory with SwinIR model file(s).\", default=os.path.join(paths.models_path, 'SwinIR'))\r\n"
  },
  {
    "path": "extensions-builtin/SwinIR/scripts/swinir_model.py",
    "content": "import logging\nimport sys\n\nimport torch\nfrom PIL import Image\n\nfrom modules import devices, modelloader, script_callbacks, shared, upscaler_utils\nfrom modules.upscaler import Upscaler, UpscalerData\n\nSWINIR_MODEL_URL = \"https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth\"\n\nlogger = logging.getLogger(__name__)\n\n\nclass UpscalerSwinIR(Upscaler):\n    def __init__(self, dirname):\n        self._cached_model = None           # keep the model when SWIN_torch_compile is on to prevent re-compile every runs\n        self._cached_model_config = None    # to clear '_cached_model' when changing model (v1/v2) or settings\n        self.name = \"SwinIR\"\n        self.model_url = SWINIR_MODEL_URL\n        self.model_name = \"SwinIR 4x\"\n        self.user_path = dirname\n        super().__init__()\n        scalers = []\n        model_files = self.find_models(ext_filter=[\".pt\", \".pth\"])\n        for model in model_files:\n            if model.startswith(\"http\"):\n                name = self.model_name\n            else:\n                name = modelloader.friendly_name(model)\n            model_data = UpscalerData(name, model, self)\n            scalers.append(model_data)\n        self.scalers = scalers\n\n    def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image:\n        current_config = (model_file, shared.opts.SWIN_tile)\n\n        if self._cached_model_config == current_config:\n            model = self._cached_model\n        else:\n            try:\n                model = self.load_model(model_file)\n            except Exception as e:\n                print(f\"Failed loading SwinIR model {model_file}: {e}\", file=sys.stderr)\n                return img\n            self._cached_model = model\n            self._cached_model_config = current_config\n\n        img = upscaler_utils.upscale_2(\n            img,\n            model,\n            
tile_size=shared.opts.SWIN_tile,\n            tile_overlap=shared.opts.SWIN_tile_overlap,\n            scale=model.scale,\n            desc=\"SwinIR\",\n        )\n        devices.torch_gc()\n        return img\n\n    def load_model(self, path, scale=4):\n        if path.startswith(\"http\"):\n            filename = modelloader.load_file_from_url(\n                url=path,\n                model_dir=self.model_download_path,\n                file_name=f\"{self.model_name.replace(' ', '_')}.pth\",\n            )\n        else:\n            filename = path\n\n        model_descriptor = modelloader.load_spandrel_model(\n            filename,\n            device=self._get_device(),\n            prefer_half=(devices.dtype == torch.float16),\n            expected_architecture=\"SwinIR\",\n        )\n        if getattr(shared.opts, 'SWIN_torch_compile', False):\n            try:\n                model_descriptor.model.compile()\n            except Exception:\n                logger.warning(\"Failed to compile SwinIR model, fallback to JIT\", exc_info=True)\n        return model_descriptor\n\n    def _get_device(self):\n        return devices.get_device_for('swinir')\n\n\ndef on_ui_settings():\n    import gradio as gr\n\n    shared.opts.add_option(\"SWIN_tile\", shared.OptionInfo(192, \"Tile size for all SwinIR.\", gr.Slider, {\"minimum\": 16, \"maximum\": 512, \"step\": 16}, section=('upscaling', \"Upscaling\")))\n    shared.opts.add_option(\"SWIN_tile_overlap\", shared.OptionInfo(8, \"Tile overlap, in pixels for SwinIR. Low values = visible seam.\", gr.Slider, {\"minimum\": 0, \"maximum\": 48, \"step\": 1}, section=('upscaling', \"Upscaling\")))\n    shared.opts.add_option(\"SWIN_torch_compile\", shared.OptionInfo(False, \"Use torch.compile to accelerate SwinIR.\", gr.Checkbox, {\"interactive\": True}, section=('upscaling', \"Upscaling\")).info(\"Takes longer on first run\"))\n\n\nscript_callbacks.on_ui_settings(on_ui_settings)\n"
  },
  {
    "path": "extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js",
    "content": "onUiLoaded(async() => {\n    const elementIDs = {\n        img2imgTabs: \"#mode_img2img .tab-nav\",\n        inpaint: \"#img2maskimg\",\n        inpaintSketch: \"#inpaint_sketch\",\n        rangeGroup: \"#img2img_column_size\",\n        sketch: \"#img2img_sketch\"\n    };\n    const tabNameToElementId = {\n        \"Inpaint sketch\": elementIDs.inpaintSketch,\n        \"Inpaint\": elementIDs.inpaint,\n        \"Sketch\": elementIDs.sketch\n    };\n\n\n    // Helper functions\n    // Get active tab\n\n    /**\n     * Waits for an element to be present in the DOM.\n     */\n    const waitForElement = (id) => new Promise(resolve => {\n        const checkForElement = () => {\n            const element = document.querySelector(id);\n            if (element) return resolve(element);\n            setTimeout(checkForElement, 100);\n        };\n        checkForElement();\n    });\n\n    function getActiveTab(elements, all = false) {\n        if (!elements.img2imgTabs) return null;\n        const tabs = elements.img2imgTabs.querySelectorAll(\"button\");\n\n        if (all) return tabs;\n\n        for (let tab of tabs) {\n            if (tab.classList.contains(\"selected\")) {\n                return tab;\n            }\n        }\n    }\n\n    // Get tab ID\n    function getTabId(elements) {\n        const activeTab = getActiveTab(elements);\n        if (!activeTab) return null;\n        return tabNameToElementId[activeTab.innerText];\n    }\n\n    // Wait until opts loaded\n    async function waitForOpts() {\n        for (; ;) {\n            if (window.opts && Object.keys(window.opts).length) {\n                return window.opts;\n            }\n            await new Promise(resolve => setTimeout(resolve, 100));\n        }\n    }\n\n    // Detect whether the element has a horizontal scroll bar\n    function hasHorizontalScrollbar(element) {\n        return element.scrollWidth > element.clientWidth;\n    }\n\n    // Function for defining the \"Ctrl\", 
\"Shift\" and \"Alt\" keys\n    function isModifierKey(event, key) {\n        switch (key) {\n        case \"Ctrl\":\n            return event.ctrlKey;\n        case \"Shift\":\n            return event.shiftKey;\n        case \"Alt\":\n            return event.altKey;\n        default:\n            return false;\n        }\n    }\n\n    // Check if hotkey is valid\n    function isValidHotkey(value) {\n        const specialKeys = [\"Ctrl\", \"Alt\", \"Shift\", \"Disable\"];\n        return (\n            (typeof value === \"string\" &&\n                value.length === 1 &&\n                /[a-z]/i.test(value)) ||\n            specialKeys.includes(value)\n        );\n    }\n\n    // Normalize hotkey\n    function normalizeHotkey(hotkey) {\n        return hotkey.length === 1 ? \"Key\" + hotkey.toUpperCase() : hotkey;\n    }\n\n    // Format hotkey for display\n    function formatHotkeyForDisplay(hotkey) {\n        return hotkey.startsWith(\"Key\") ? hotkey.slice(3) : hotkey;\n    }\n\n    // Create hotkey configuration with the provided options\n    function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) {\n        const result = {}; // Resulting hotkey configuration\n        const usedKeys = new Set(); // Set of used hotkeys\n\n        // Iterate through defaultHotkeysConfig keys\n        for (const key in defaultHotkeysConfig) {\n            const userValue = hotkeysConfigOpts[key]; // User-provided hotkey value\n            const defaultValue = defaultHotkeysConfig[key]; // Default hotkey value\n\n            // Apply appropriate value for undefined, boolean, or object userValue\n            if (\n                userValue === undefined ||\n                typeof userValue === \"boolean\" ||\n                typeof userValue === \"object\" ||\n                userValue === \"disable\"\n            ) {\n                result[key] =\n                    userValue === undefined ? 
defaultValue : userValue;\n            } else if (isValidHotkey(userValue)) {\n                const normalizedUserValue = normalizeHotkey(userValue);\n\n                // Check for conflicting hotkeys\n                if (!usedKeys.has(normalizedUserValue)) {\n                    usedKeys.add(normalizedUserValue);\n                    result[key] = normalizedUserValue;\n                } else {\n                    console.error(\n                        `Hotkey: ${formatHotkeyForDisplay(\n                            userValue\n                        )} for ${key} is repeated and conflicts with another hotkey. The default hotkey is used: ${formatHotkeyForDisplay(\n                            defaultValue\n                        )}`\n                    );\n                    result[key] = defaultValue;\n                }\n            } else {\n                console.error(\n                    `Hotkey: ${formatHotkeyForDisplay(\n                        userValue\n                    )} for ${key} is not valid. 
The default hotkey is used: ${formatHotkeyForDisplay(\n                        defaultValue\n                    )}`\n                );\n                result[key] = defaultValue;\n            }\n        }\n\n        return result;\n    }\n\n    // Disables functions in the config object based on the provided list of function names\n    function disableFunctions(config, disabledFunctions) {\n        // Bind the hasOwnProperty method to the functionMap object to avoid errors\n        const hasOwnProperty =\n            Object.prototype.hasOwnProperty.bind(functionMap);\n\n        // Loop through the disabledFunctions array and disable the corresponding functions in the config object\n        disabledFunctions.forEach(funcName => {\n            if (hasOwnProperty(funcName)) {\n                const key = functionMap[funcName];\n                config[key] = \"disable\";\n            }\n        });\n\n        // Return the updated config object\n        return config;\n    }\n\n    /**\n     * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio.\n     * If the image display property is set to 'none', the mask breaks. To fix this, the function\n     * temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds\n     * to avoid breaking the canvas. 
Additionally, the function adjusts the mask to work correctly on\n     * very long images.\n     */\n    function restoreImgRedMask(elements) {\n        const mainTabId = getTabId(elements);\n\n        if (!mainTabId) return;\n\n        const mainTab = gradioApp().querySelector(mainTabId);\n        const img = mainTab.querySelector(\"img\");\n        const imageARPreview = gradioApp().querySelector(\"#imageARPreview\");\n\n        if (!img || !imageARPreview) return;\n\n        imageARPreview.style.transform = \"\";\n        if (parseFloat(mainTab.style.width) > 865) {\n            const transformString = mainTab.style.transform;\n            const scaleMatch = transformString.match(\n                /scale\\(([-+]?[0-9]*\\.?[0-9]+)\\)/\n            );\n            let zoom = 1; // default zoom\n\n            if (scaleMatch && scaleMatch[1]) {\n                zoom = Number(scaleMatch[1]);\n            }\n\n            imageARPreview.style.transformOrigin = \"0 0\";\n            imageARPreview.style.transform = `scale(${zoom})`;\n        }\n\n        if (img.style.display !== \"none\") return;\n\n        img.style.display = \"block\";\n\n        setTimeout(() => {\n            img.style.display = \"none\";\n        }, 400);\n    }\n\n    const hotkeysConfigOpts = await waitForOpts();\n\n    // Default config\n    const defaultHotkeysConfig = {\n        canvas_hotkey_zoom: \"Alt\",\n        canvas_hotkey_adjust: \"Ctrl\",\n        canvas_hotkey_reset: \"KeyR\",\n        canvas_hotkey_fullscreen: \"KeyS\",\n        canvas_hotkey_move: \"KeyF\",\n        canvas_hotkey_overlap: \"KeyO\",\n        canvas_hotkey_shrink_brush: \"KeyQ\",\n        canvas_hotkey_grow_brush: \"KeyW\",\n        canvas_disabled_functions: [],\n        canvas_show_tooltip: true,\n        canvas_auto_expand: true,\n        canvas_blur_prompt: false,\n    };\n\n    const functionMap = {\n        \"Zoom\": \"canvas_hotkey_zoom\",\n        \"Adjust brush size\": \"canvas_hotkey_adjust\",\n        
\"Hotkey shrink brush\": \"canvas_hotkey_shrink_brush\",\n        \"Hotkey enlarge brush\": \"canvas_hotkey_grow_brush\",\n        \"Moving canvas\": \"canvas_hotkey_move\",\n        \"Fullscreen\": \"canvas_hotkey_fullscreen\",\n        \"Reset Zoom\": \"canvas_hotkey_reset\",\n        \"Overlap\": \"canvas_hotkey_overlap\"\n    };\n\n    // Loading the configuration from opts\n    const preHotkeysConfig = createHotkeyConfig(\n        defaultHotkeysConfig,\n        hotkeysConfigOpts\n    );\n\n    // Disable functions that are not needed by the user\n    const hotkeysConfig = disableFunctions(\n        preHotkeysConfig,\n        preHotkeysConfig.canvas_disabled_functions\n    );\n\n    let isMoving = false;\n    let mouseX, mouseY;\n    let activeElement;\n    let interactedWithAltKey = false;\n\n    const elements = Object.fromEntries(\n        Object.keys(elementIDs).map(id => [\n            id,\n            gradioApp().querySelector(elementIDs[id])\n        ])\n    );\n    const elemData = {};\n\n    // Apply functionality to the range inputs. 
Restore redmask and correct for long images.\n    const rangeInputs = elements.rangeGroup ?\n        Array.from(elements.rangeGroup.querySelectorAll(\"input\")) :\n        [\n            gradioApp().querySelector(\"#img2img_width input[type='range']\"),\n            gradioApp().querySelector(\"#img2img_height input[type='range']\")\n        ];\n\n    for (const input of rangeInputs) {\n        input?.addEventListener(\"input\", () => restoreImgRedMask(elements));\n    }\n\n    function applyZoomAndPan(elemId, isExtension = true) {\n        const targetElement = gradioApp().querySelector(elemId);\n\n        if (!targetElement) {\n            console.log(\"Element not found\", elemId);\n            return;\n        }\n\n        targetElement.style.transformOrigin = \"0 0\";\n\n        elemData[elemId] = {\n            zoom: 1,\n            panX: 0,\n            panY: 0\n        };\n        let fullScreenMode = false;\n\n        // Create tooltip\n        function createTooltip() {\n            const toolTipElement =\n                targetElement.querySelector(\".image-container\");\n            const tooltip = document.createElement(\"div\");\n            tooltip.className = \"canvas-tooltip\";\n\n            // Creating an item of information\n            const info = document.createElement(\"i\");\n            info.className = \"canvas-tooltip-info\";\n            info.textContent = \"\";\n\n            // Create a container for the contents of the tooltip\n            const tooltipContent = document.createElement(\"div\");\n            tooltipContent.className = \"canvas-tooltip-content\";\n\n            // Define an array with hotkey information and their actions\n            const hotkeysInfo = [\n                {\n                    configKey: \"canvas_hotkey_zoom\",\n                    action: \"Zoom canvas\",\n                    keySuffix: \" + wheel\"\n                },\n                {\n                    configKey: \"canvas_hotkey_adjust\",\n      
              action: \"Adjust brush size\",\n                    keySuffix: \" + wheel\"\n                },\n                {configKey: \"canvas_hotkey_reset\", action: \"Reset zoom\"},\n                {\n                    configKey: \"canvas_hotkey_fullscreen\",\n                    action: \"Fullscreen mode\"\n                },\n                {configKey: \"canvas_hotkey_move\", action: \"Move canvas\"},\n                {configKey: \"canvas_hotkey_overlap\", action: \"Overlap\"}\n            ];\n\n            // Create hotkeys array with disabled property based on the config values\n            const hotkeys = hotkeysInfo.map(info => {\n                const configValue = hotkeysConfig[info.configKey];\n                const key = info.keySuffix ?\n                    `${configValue}${info.keySuffix}` :\n                    configValue.charAt(configValue.length - 1);\n                return {\n                    key,\n                    action: info.action,\n                    disabled: configValue === \"disable\"\n                };\n            });\n\n            for (const hotkey of hotkeys) {\n                if (hotkey.disabled) {\n                    continue;\n                }\n\n                const p = document.createElement(\"p\");\n                p.innerHTML = `<b>${hotkey.key}</b> - ${hotkey.action}`;\n                tooltipContent.appendChild(p);\n            }\n\n            // Add information and content elements to the tooltip element\n            tooltip.appendChild(info);\n            tooltip.appendChild(tooltipContent);\n\n            // Add a hint element to the target element\n            toolTipElement.appendChild(tooltip);\n        }\n\n        //Show tool tip if setting enable\n        if (hotkeysConfig.canvas_show_tooltip) {\n            createTooltip();\n        }\n\n        // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. 
This hack allows you to almost never think about this problem, it has no effect on webui.\n        function fixCanvas() {\n            const activeTab = getActiveTab(elements)?.textContent.trim();\n\n            if (activeTab && activeTab !== \"img2img\") {\n                const img = targetElement.querySelector(`${elemId} img`);\n\n                if (img && img.style.display !== \"none\") {\n                    img.style.display = \"none\";\n                    img.style.visibility = \"hidden\";\n                }\n            }\n        }\n\n        // Reset the zoom level and pan position of the target element to their initial values\n        function resetZoom() {\n            elemData[elemId] = {\n                zoomLevel: 1,\n                panX: 0,\n                panY: 0\n            };\n\n            if (isExtension) {\n                targetElement.style.overflow = \"hidden\";\n            }\n\n            targetElement.isZoomed = false;\n\n            fixCanvas();\n            targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`;\n\n            const canvas = gradioApp().querySelector(\n                `${elemId} canvas[key=\"interface\"]`\n            );\n\n            toggleOverlap(\"off\");\n            fullScreenMode = false;\n\n            const closeBtn = targetElement.querySelector(\"button[aria-label='Remove Image']\");\n            if (closeBtn) {\n                closeBtn.addEventListener(\"click\", resetZoom);\n            }\n\n            if (canvas && isExtension) {\n                const parentElement = targetElement.closest('[id^=\"component-\"]');\n                if (\n                    canvas &&\n                    parseFloat(canvas.style.width) > parentElement.offsetWidth &&\n                    parseFloat(targetElement.style.width) > parentElement.offsetWidth\n                ) {\n                    fitToElement();\n                    
return;\n                }\n\n            }\n\n            if (\n                canvas &&\n                !isExtension &&\n                parseFloat(canvas.style.width) > 865 &&\n                parseFloat(targetElement.style.width) > 865\n            ) {\n                fitToElement();\n                return;\n            }\n\n            targetElement.style.width = \"\";\n        }\n\n        // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements\n        function toggleOverlap(forced = \"\") {\n            const zIndex1 = \"0\";\n            const zIndex2 = \"998\";\n\n            targetElement.style.zIndex =\n                targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1;\n\n            if (forced === \"off\") {\n                targetElement.style.zIndex = zIndex1;\n            } else if (forced === \"on\") {\n                targetElement.style.zIndex = zIndex2;\n            }\n        }\n\n        // Adjust the brush size based on the deltaY value from a mouse wheel event\n        function adjustBrushSize(\n            elemId,\n            deltaY,\n            withoutValue = false,\n            percentage = 5\n        ) {\n            const input =\n                gradioApp().querySelector(\n                    `${elemId} input[aria-label='Brush radius']`\n                ) ||\n                gradioApp().querySelector(\n                    `${elemId} button[aria-label=\"Use brush\"]`\n                );\n\n            if (input) {\n                input.click();\n                if (!withoutValue) {\n                    const maxValue =\n                        parseFloat(input.getAttribute(\"max\")) || 100;\n                    const changeAmount = maxValue * (percentage / 100);\n                    const newValue =\n                        parseFloat(input.value) +\n                        (deltaY > 0 ? 
-changeAmount : changeAmount);\n                    input.value = Math.min(Math.max(newValue, 0), maxValue);\n                    input.dispatchEvent(new Event(\"change\"));\n                }\n            }\n        }\n\n        // Reset zoom when uploading a new image\n        const fileInput = gradioApp().querySelector(\n            `${elemId} input[type=\"file\"][accept=\"image/*\"].svelte-116rqfv`\n        );\n        fileInput.addEventListener(\"click\", resetZoom);\n\n        // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables\n        function updateZoom(newZoomLevel, mouseX, mouseY) {\n            newZoomLevel = Math.max(0.1, Math.min(newZoomLevel, 15));\n\n            elemData[elemId].panX +=\n                mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel;\n            elemData[elemId].panY +=\n                mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel;\n\n            targetElement.style.transformOrigin = \"0 0\";\n            targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`;\n\n            toggleOverlap(\"on\");\n            if (isExtension) {\n                targetElement.style.overflow = \"visible\";\n            }\n\n            return newZoomLevel;\n        }\n\n        // Change the zoom level based on user interaction\n        function changeZoomLevel(operation, e) {\n            if (isModifierKey(e, hotkeysConfig.canvas_hotkey_zoom)) {\n                e.preventDefault();\n\n                if (hotkeysConfig.canvas_hotkey_zoom === \"Alt\") {\n                    interactedWithAltKey = true;\n                }\n\n                let zoomPosX, zoomPosY;\n                let delta = 0.2;\n                if (elemData[elemId].zoomLevel > 7) {\n                    delta = 0.9;\n                } else if (elemData[elemId].zoomLevel > 2) {\n                    delta = 0.6;\n 
               }\n\n                zoomPosX = e.clientX;\n                zoomPosY = e.clientY;\n\n                fullScreenMode = false;\n                elemData[elemId].zoomLevel = updateZoom(\n                    elemData[elemId].zoomLevel +\n                    (operation === \"+\" ? delta : -delta),\n                    zoomPosX - targetElement.getBoundingClientRect().left,\n                    zoomPosY - targetElement.getBoundingClientRect().top\n                );\n\n                targetElement.isZoomed = true;\n            }\n        }\n\n        /**\n         * This function fits the target element to the screen by calculating\n         * the required scale and offsets. It also updates the global variables\n         * zoomLevel, panX, and panY to reflect the new state.\n         */\n\n        function fitToElement() {\n            //Reset Zoom\n            targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;\n\n            let parentElement;\n\n            if (isExtension) {\n                parentElement = targetElement.closest('[id^=\"component-\"]');\n            } else {\n                parentElement = targetElement.parentElement;\n            }\n\n\n            // Get element and screen dimensions\n            const elementWidth = targetElement.offsetWidth;\n            const elementHeight = targetElement.offsetHeight;\n\n            const screenWidth = parentElement.clientWidth;\n            const screenHeight = parentElement.clientHeight;\n\n            // Get element's coordinates relative to the parent element\n            const elementRect = targetElement.getBoundingClientRect();\n            const parentRect = parentElement.getBoundingClientRect();\n            const elementX = elementRect.x - parentRect.x;\n\n            // Calculate scale and offsets\n            const scaleX = screenWidth / elementWidth;\n            const scaleY = screenHeight / elementHeight;\n            const scale = Math.min(scaleX, scaleY);\n\n 
           const transformOrigin =\n                window.getComputedStyle(targetElement).transformOrigin;\n            const [originX, originY] = transformOrigin.split(\" \");\n            const originXValue = parseFloat(originX);\n            const originYValue = parseFloat(originY);\n\n            const offsetX =\n                (screenWidth - elementWidth * scale) / 2 -\n                originXValue * (1 - scale);\n            const offsetY =\n                (screenHeight - elementHeight * scale) / 2.5 -\n                originYValue * (1 - scale);\n\n            // Apply scale and offsets to the element\n            targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;\n\n            // Update global variables\n            elemData[elemId].zoomLevel = scale;\n            elemData[elemId].panX = offsetX;\n            elemData[elemId].panY = offsetY;\n\n            fullScreenMode = false;\n            toggleOverlap(\"off\");\n        }\n\n        /**\n         * This function fits the target element to the screen by calculating\n         * the required scale and offsets. 
It also updates the global variables\n         * zoomLevel, panX, and panY to reflect the new state.\n         */\n\n        // Fullscreen mode\n        function fitToScreen() {\n            const canvas = gradioApp().querySelector(\n                `${elemId} canvas[key=\"interface\"]`\n            );\n\n            if (!canvas) return;\n\n            if (canvas.offsetWidth > 862 || isExtension) {\n                targetElement.style.width = (canvas.offsetWidth + 2) + \"px\";\n            }\n\n            if (isExtension) {\n                targetElement.style.overflow = \"visible\";\n            }\n\n            if (fullScreenMode) {\n                resetZoom();\n                fullScreenMode = false;\n                return;\n            }\n\n            //Reset Zoom\n            targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`;\n\n            // Get scrollbar width to right-align the image\n            const scrollbarWidth =\n                window.innerWidth - document.documentElement.clientWidth;\n\n            // Get element and screen dimensions\n            const elementWidth = targetElement.offsetWidth;\n            const elementHeight = targetElement.offsetHeight;\n            const screenWidth = window.innerWidth - scrollbarWidth;\n            const screenHeight = window.innerHeight;\n\n            // Get element's coordinates relative to the page\n            const elementRect = targetElement.getBoundingClientRect();\n            const elementY = elementRect.y;\n            const elementX = elementRect.x;\n\n            // Calculate scale and offsets\n            const scaleX = screenWidth / elementWidth;\n            const scaleY = screenHeight / elementHeight;\n            const scale = Math.min(scaleX, scaleY);\n\n            // Get the current transformOrigin\n            const computedStyle = window.getComputedStyle(targetElement);\n            const transformOrigin = computedStyle.transformOrigin;\n            const 
[originX, originY] = transformOrigin.split(\" \");\n            const originXValue = parseFloat(originX);\n            const originYValue = parseFloat(originY);\n\n            // Calculate offsets with respect to the transformOrigin\n            const offsetX =\n                (screenWidth - elementWidth * scale) / 2 -\n                elementX -\n                originXValue * (1 - scale);\n            const offsetY =\n                (screenHeight - elementHeight * scale) / 2 -\n                elementY -\n                originYValue * (1 - scale);\n\n            // Apply scale and offsets to the element\n            targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`;\n\n            // Update global variables\n            elemData[elemId].zoomLevel = scale;\n            elemData[elemId].panX = offsetX;\n            elemData[elemId].panY = offsetY;\n\n            fullScreenMode = true;\n            toggleOverlap(\"on\");\n        }\n\n        // Handle keydown events\n        function handleKeyDown(event) {\n            // Disable key locks to make pasting from the buffer work correctly\n            if ((event.ctrlKey && event.code === 'KeyV') || (event.ctrlKey && event.code === 'KeyC') || event.code === \"F5\") {\n                return;\n            }\n\n            // before activating shortcut, ensure user is not actively typing in an input field\n            if (!hotkeysConfig.canvas_blur_prompt) {\n                if (event.target.nodeName === 'TEXTAREA' || event.target.nodeName === 'INPUT') {\n                    return;\n                }\n            }\n\n\n            const hotkeyActions = {\n                [hotkeysConfig.canvas_hotkey_reset]: resetZoom,\n                [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap,\n                [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen,\n                [hotkeysConfig.canvas_hotkey_shrink_brush]: () => adjustBrushSize(elemId, 10),\n                
[hotkeysConfig.canvas_hotkey_grow_brush]: () => adjustBrushSize(elemId, -10)\n            };\n\n            const action = hotkeyActions[event.code];\n            if (action) {\n                event.preventDefault();\n                action(event);\n            }\n\n            if (\n                isModifierKey(event, hotkeysConfig.canvas_hotkey_zoom) ||\n                isModifierKey(event, hotkeysConfig.canvas_hotkey_adjust)\n            ) {\n                event.preventDefault();\n            }\n        }\n\n        // Get Mouse position\n        function getMousePosition(e) {\n            mouseX = e.offsetX;\n            mouseY = e.offsetY;\n        }\n\n        // Simulation of the function to put a long image into the screen.\n        // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element.\n        // We hide the image and show it to the user when it is ready.\n\n        targetElement.isExpanded = false;\n        function autoExpand() {\n            const canvas = document.querySelector(`${elemId} canvas[key=\"interface\"]`);\n            if (canvas) {\n                if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) {\n                    targetElement.style.visibility = \"hidden\";\n                    setTimeout(() => {\n                        fitToScreen();\n                        resetZoom();\n                        targetElement.style.visibility = \"visible\";\n                        targetElement.isExpanded = true;\n                    }, 10);\n                }\n            }\n        }\n\n        targetElement.addEventListener(\"mousemove\", getMousePosition);\n\n        //observers\n        // Creating an observer with a callback function to handle DOM changes\n        const observer = new MutationObserver((mutationsList, observer) => {\n            for (let mutation of mutationsList) {\n                // If the style attribute of the 
canvas has changed, by observation it happens only when the picture changes\n                if (mutation.type === 'attributes' && mutation.attributeName === 'style' &&\n                    mutation.target.tagName.toLowerCase() === 'canvas') {\n                    targetElement.isExpanded = false;\n                    setTimeout(resetZoom, 10);\n                }\n            }\n        });\n\n        // Apply auto expand if enabled\n        if (hotkeysConfig.canvas_auto_expand) {\n            targetElement.addEventListener(\"mousemove\", autoExpand);\n            // Set up an observer to track attribute changes\n            observer.observe(targetElement, {attributes: true, childList: true, subtree: true});\n        }\n\n        // Handle events only inside the targetElement\n        let isKeyDownHandlerAttached = false;\n\n        function handleMouseMove() {\n            if (!isKeyDownHandlerAttached) {\n                document.addEventListener(\"keydown\", handleKeyDown);\n                isKeyDownHandlerAttached = true;\n\n                activeElement = elemId;\n            }\n        }\n\n        function handleMouseLeave() {\n            if (isKeyDownHandlerAttached) {\n                document.removeEventListener(\"keydown\", handleKeyDown);\n                isKeyDownHandlerAttached = false;\n\n                activeElement = null;\n            }\n        }\n\n        // Add mouse event handlers\n        targetElement.addEventListener(\"mousemove\", handleMouseMove);\n        targetElement.addEventListener(\"mouseleave\", handleMouseLeave);\n\n        // Reset zoom when click on another tab\n        if (elements.img2imgTabs) {\n            elements.img2imgTabs.addEventListener(\"click\", resetZoom);\n            elements.img2imgTabs.addEventListener(\"click\", () => {\n                // targetElement.style.width = \"\";\n                if (parseInt(targetElement.style.width) > 865) {\n                    setTimeout(fitToElement, 0);\n                }\n 
           });\n        }\n\n        targetElement.addEventListener(\"wheel\", e => {\n            // change zoom level\n            const operation = (e.deltaY || -e.wheelDelta) > 0 ? \"-\" : \"+\";\n            changeZoomLevel(operation, e);\n\n            // Handle brush size adjustment with ctrl key pressed\n            if (isModifierKey(e, hotkeysConfig.canvas_hotkey_adjust)) {\n                e.preventDefault();\n\n                if (hotkeysConfig.canvas_hotkey_adjust === \"Alt\") {\n                    interactedWithAltKey = true;\n                }\n\n                // Increase or decrease brush size based on scroll direction\n                adjustBrushSize(elemId, e.deltaY);\n            }\n        });\n\n        // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element.\n        function handleMoveKeyDown(e) {\n\n            // Disable key locks to make pasting from the buffer work correctly\n            if ((e.ctrlKey && e.code === 'KeyV') || (e.ctrlKey && event.code === 'KeyC') || e.code === \"F5\") {\n                return;\n            }\n\n            // before activating shortcut, ensure user is not actively typing in an input field\n            if (!hotkeysConfig.canvas_blur_prompt) {\n                if (e.target.nodeName === 'TEXTAREA' || e.target.nodeName === 'INPUT') {\n                    return;\n                }\n            }\n\n\n            if (e.code === hotkeysConfig.canvas_hotkey_move) {\n                if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) {\n                    e.preventDefault();\n                    document.activeElement.blur();\n                    isMoving = true;\n                }\n            }\n        }\n\n        function handleMoveKeyUp(e) {\n            if (e.code === hotkeysConfig.canvas_hotkey_move) {\n                isMoving = false;\n            }\n        }\n\n        document.addEventListener(\"keydown\", 
handleMoveKeyDown);\n        document.addEventListener(\"keyup\", handleMoveKeyUp);\n\n\n        // Prevent firefox from opening main menu when alt is used as a hotkey for zoom or brush size\n        function handleAltKeyUp(e) {\n            if (e.key !== \"Alt\" || !interactedWithAltKey) {\n                return;\n            }\n\n            e.preventDefault();\n            interactedWithAltKey = false;\n        }\n\n        document.addEventListener(\"keyup\", handleAltKeyUp);\n\n\n        // Detect zoom level and update the pan speed.\n        function updatePanPosition(movementX, movementY) {\n            let panSpeed = 2;\n\n            if (elemData[elemId].zoomLevel > 8) {\n                panSpeed = 3.5;\n            }\n\n            elemData[elemId].panX += movementX * panSpeed;\n            elemData[elemId].panY += movementY * panSpeed;\n\n            // Delayed redraw of an element\n            requestAnimationFrame(() => {\n                targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`;\n                toggleOverlap(\"on\");\n            });\n        }\n\n        function handleMoveByKey(e) {\n            if (isMoving && elemId === activeElement) {\n                updatePanPosition(e.movementX, e.movementY);\n                targetElement.style.pointerEvents = \"none\";\n\n                if (isExtension) {\n                    targetElement.style.overflow = \"visible\";\n                }\n\n            } else {\n                targetElement.style.pointerEvents = \"auto\";\n            }\n        }\n\n        // Prevents sticking to the mouse\n        window.onblur = function() {\n            isMoving = false;\n        };\n\n        // Checks for extension\n        function checkForOutBox() {\n            const parentElement = targetElement.closest('[id^=\"component-\"]');\n            if (parentElement.offsetWidth < targetElement.offsetWidth && 
!targetElement.isExpanded) {\n                resetZoom();\n                targetElement.isExpanded = true;\n            }\n\n            if (parentElement.offsetWidth < targetElement.offsetWidth && elemData[elemId].zoomLevel == 1) {\n                resetZoom();\n            }\n\n            if (parentElement.offsetWidth < targetElement.offsetWidth && targetElement.offsetWidth * elemData[elemId].zoomLevel > parentElement.offsetWidth && elemData[elemId].zoomLevel < 1 && !targetElement.isZoomed) {\n                resetZoom();\n            }\n        }\n\n        if (isExtension) {\n            targetElement.addEventListener(\"mousemove\", checkForOutBox);\n        }\n\n\n        window.addEventListener('resize', (e) => {\n            resetZoom();\n\n            if (isExtension) {\n                targetElement.isExpanded = false;\n                targetElement.isZoomed = false;\n            }\n        });\n\n        gradioApp().addEventListener(\"mousemove\", handleMoveByKey);\n\n\n    }\n\n    applyZoomAndPan(elementIDs.sketch, false);\n    applyZoomAndPan(elementIDs.inpaint, false);\n    applyZoomAndPan(elementIDs.inpaintSketch, false);\n\n    // Make the function global so that other extensions can take advantage of this solution\n    const applyZoomAndPanIntegration = async(id, elementIDs) => {\n        const mainEl = document.querySelector(id);\n        if (id.toLocaleLowerCase() === \"none\") {\n            for (const elementID of elementIDs) {\n                const el = await waitForElement(elementID);\n                if (!el) break;\n                applyZoomAndPan(elementID);\n            }\n            return;\n        }\n\n        if (!mainEl) return;\n        mainEl.addEventListener(\"click\", async() => {\n            for (const elementID of elementIDs) {\n                const el = await waitForElement(elementID);\n                if (!el) break;\n                applyZoomAndPan(elementID);\n            }\n        }, {once: true});\n    };\n\n    
    window.applyZoomAndPan = applyZoomAndPan; // Applies to a single element; pass its selector as elementID, for example applyZoomAndPan(\"#txt2img_controlnet_ControlNet_input_image\")\n\n    window.applyZoomAndPanIntegration = applyZoomAndPanIntegration; // so that any extension can use it\n\n    /*\n        The function `applyZoomAndPanIntegration` takes two arguments:\n\n        1. `id`: A CSS selector for the trigger element. Zoom and pan are not applied to this element itself; the first click on it applies them to every element listed in `elementIDs`.\n        If no element matches `id`, nothing happens. If the `id` value is \"none\" (case-insensitive), no click is awaited and the functionality is applied to all elements in the second argument immediately.\n\n        2. `elementIDs`: An array of CSS selectors. Zoom and pan functionality will be applied to each of these elements on the first click of the element specified by the first argument.\n        If \"none\" is specified in the first argument, the functionality will be applied to each of these elements without a click event. Elements are processed in order; if one of them is not found, the remaining ones are skipped.\n\n        Example usage:\n        applyZoomAndPanIntegration(\"#txt2img_controlnet\", [\"#txt2img_controlnet_ControlNet_input_image\"]);\n        In this example, zoom and pan functionality will be applied to the element with the identifier \"txt2img_controlnet_ControlNet_input_image\" upon the first click on the element with the identifier \"txt2img_controlnet\".\n    */\n\n    // More examples\n    // Add integration with ControlNet txt2img (single unit)\n    // applyZoomAndPanIntegration(\"#txt2img_controlnet\", [\"#txt2img_controlnet_ControlNet_input_image\"]);\n\n    // Add integration with ControlNet txt2img (tabbed units ControlNet-0 to ControlNet-9)\n    // applyZoomAndPanIntegration(\"#txt2img_controlnet\",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`));\n\n    // Add integration with Inpaint Anything\n    // applyZoomAndPanIntegration(\"None\", [\"#ia_sam_image\", \"#ia_sel_mask\"]);\n});\n"
  },
  {
    "path": "extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py",
    "content": "import gradio as gr\nfrom modules import shared\n\nshared.options_templates.update(shared.options_section(('canvas_hotkey', \"Canvas Hotkeys\"), {\n    \"canvas_hotkey_zoom\": shared.OptionInfo(\"Alt\", \"Zoom canvas\", gr.Radio, {\"choices\": [\"Shift\",\"Ctrl\", \"Alt\"]}).info(\"If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox\"),\n    \"canvas_hotkey_adjust\": shared.OptionInfo(\"Ctrl\", \"Adjust brush size\", gr.Radio, {\"choices\": [\"Shift\",\"Ctrl\", \"Alt\"]}).info(\"If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox\"),\n    \"canvas_hotkey_shrink_brush\": shared.OptionInfo(\"Q\", \"Shrink the brush size\"),\n    \"canvas_hotkey_grow_brush\": shared.OptionInfo(\"W\", \"Enlarge the brush size\"),\n    \"canvas_hotkey_move\": shared.OptionInfo(\"F\", \"Moving the canvas\").info(\"To work correctly in firefox, turn off 'Automatically search the page text when typing' in the browser settings\"),\n    \"canvas_hotkey_fullscreen\": shared.OptionInfo(\"S\", \"Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width \"),\n    \"canvas_hotkey_reset\": shared.OptionInfo(\"R\", \"Reset zoom and canvas position\"),\n    \"canvas_hotkey_overlap\": shared.OptionInfo(\"O\", \"Toggle overlap\").info(\"Technical button, needed for testing\"),\n    \"canvas_show_tooltip\": shared.OptionInfo(True, \"Enable tooltip on the canvas\"),\n    \"canvas_auto_expand\": shared.OptionInfo(True, \"Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons\"),\n    \"canvas_blur_prompt\": shared.OptionInfo(False, \"Take the focus off the prompt when working with a canvas\"),\n    \"canvas_disabled_functions\": shared.OptionInfo([\"Overlap\"], \"Disable function that you don't use\", gr.CheckboxGroup, {\"choices\": [\"Zoom\",\"Adjust brush size\",\"Hotkey 
enlarge brush\",\"Hotkey shrink brush\",\"Moving canvas\",\"Fullscreen\",\"Reset Zoom\",\"Overlap\"]}),\n}))\n"
  },
  {
    "path": "extensions-builtin/canvas-zoom-and-pan/style.css",
    "content": ".canvas-tooltip-info {\n  position: absolute;\n  top: 10px;\n  left: 10px;\n  cursor: help;\n  background-color: rgba(0, 0, 0, 0.3);\n  width: 20px;\n  height: 20px;\n  border-radius: 50%;\n  display: flex;\n  align-items: center;\n  justify-content: center;\n  flex-direction: column; \n\n  z-index: 100;\n}\n\n.canvas-tooltip-info::after {\n  content: '';\n  display: block;\n  width: 2px;\n  height: 7px;\n  background-color: white;\n  margin-top: 2px; \n}\n\n.canvas-tooltip-info::before {\n  content: '';\n  display: block;\n  width: 2px;\n  height: 2px;\n  background-color: white;\n}\n\n.canvas-tooltip-content {\n  display: none;\n  background-color: #f9f9f9; \n  color: #333; \n  border: 1px solid #ddd;\n  padding: 15px; \n  position: absolute;\n  top: 40px;\n  left: 10px;\n  width: 250px;\n  font-size: 16px;\n  opacity: 0;\n  border-radius: 8px;\n  box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); \n\n  z-index: 100;\n}\n\n.canvas-tooltip:hover .canvas-tooltip-content {\n  display: block;\n  animation: fadeIn 0.5s; \n  opacity: 1;\n}\n\n@keyframes fadeIn {\n  from {opacity: 0;}\n  to {opacity: 1;}\n}\n\n.styler {\n  overflow:inherit !important;\n}"
  },
  {
    "path": "extensions-builtin/extra-options-section/scripts/extra_options_section.py",
    "content": "import math\r\n\r\nimport gradio as gr\r\nfrom modules import scripts, shared, ui_components, ui_settings, infotext_utils, errors\r\nfrom modules.ui_components import FormColumn\r\n\r\n\r\nclass ExtraOptionsSection(scripts.Script):\r\n    section = \"extra_options\"\r\n\r\n    def __init__(self):\r\n        self.comps = None\r\n        self.setting_names = None\r\n\r\n    def title(self):\r\n        return \"Extra options\"\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def ui(self, is_img2img):\r\n        self.comps = []\r\n        self.setting_names = []\r\n        self.infotext_fields = []\r\n        extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img\r\n        elem_id_tabname = \"extra_options_\" + (\"img2img\" if is_img2img else \"txt2img\")\r\n\r\n        mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}\r\n\r\n        with gr.Blocks() as interface:\r\n            with gr.Accordion(\"Options\", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):\r\n\r\n                row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols)\r\n\r\n                for row in range(row_count):\r\n                    with gr.Row():\r\n                        for col in range(shared.opts.extra_options_cols):\r\n                            index = row * shared.opts.extra_options_cols + col\r\n                            if index >= len(extra_options):\r\n                                break\r\n\r\n                            setting_name = extra_options[index]\r\n\r\n                            with FormColumn():\r\n                                try:\r\n                                    comp = ui_settings.create_setting_component(setting_name)\r\n                                except KeyError:\r\n                                    
errors.report(f\"Can't add extra options for {setting_name} in ui\")\r\n                                    continue\r\n\r\n                            self.comps.append(comp)\r\n                            self.setting_names.append(setting_name)\r\n\r\n                            setting_infotext_name = mapping.get(setting_name)\r\n                            if setting_infotext_name is not None:\r\n                                self.infotext_fields.append((comp, setting_infotext_name))\r\n\r\n        def get_settings_values():\r\n            res = [ui_settings.get_value_for_setting(key) for key in self.setting_names]\r\n            return res[0] if len(res) == 1 else res\r\n\r\n        interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False)\r\n\r\n        return self.comps\r\n\r\n    def before_process(self, p, *args):\r\n        for name, value in zip(self.setting_names, args):\r\n            if name not in p.override_settings:\r\n                p.override_settings[name] = value\r\n\r\n\r\nshared.options_templates.update(shared.options_section(('settings_in_ui', \"Settings in UI\", \"ui\"), {\r\n    \"settings_in_ui\": shared.OptionHTML(\"\"\"\r\nThis page allows you to add some settings to the main interface of txt2img and img2img tabs.\r\n\"\"\"),\r\n    \"extra_options_txt2img\": shared.OptionInfo([], \"Settings for txt2img\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared.opts.data_labels.keys())}).js(\"info\", \"settingsHintsShowQuicksettings\").info(\"setting entries that also appear in txt2img interfaces\").needs_reload_ui(),\r\n    \"extra_options_img2img\": shared.OptionInfo([], \"Settings for img2img\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared.opts.data_labels.keys())}).js(\"info\", \"settingsHintsShowQuicksettings\").info(\"setting entries that also appear in img2img interfaces\").needs_reload_ui(),\r\n    \"extra_options_cols\": shared.OptionInfo(1, \"Number of 
columns for added settings\", gr.Slider, {\"step\": 1, \"minimum\": 1, \"maximum\": 20}).info(\"displayed amount will depend on the actual browser window width\").needs_reload_ui(),\r\n    \"extra_options_accordion\": shared.OptionInfo(False, \"Place added settings into an accordion\").needs_reload_ui()\r\n}))\r\n\r\n\r\n"
  },
  {
    "path": "extensions-builtin/hypertile/hypertile.py",
    "content": "\"\"\"\nHypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE\nWarn: The patch works well only if the input image has a width and height that are multiples of 128\nOriginal author: @tfernd Github: https://github.com/tfernd/HyperTile\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Callable\n\nfrom functools import wraps, cache\n\nimport math\nimport torch.nn as nn\nimport random\n\nfrom einops import rearrange\n\n\n@dataclass\nclass HypertileParams:\n    depth = 0\n    layer_name = \"\"\n    tile_size: int = 0\n    swap_size: int = 0\n    aspect_ratio: float = 1.0\n    forward = None\n    enabled = False\n\n\n\n# TODO add SD-XL layers\nDEPTH_LAYERS = {\n    0: [\n        # SD 1.5 U-Net (diffusers)\n        \"down_blocks.0.attentions.0.transformer_blocks.0.attn1\",\n        \"down_blocks.0.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.0.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.2.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"input_blocks.1.1.transformer_blocks.0.attn1\",\n        \"input_blocks.2.1.transformer_blocks.0.attn1\",\n        \"output_blocks.9.1.transformer_blocks.0.attn1\",\n        \"output_blocks.10.1.transformer_blocks.0.attn1\",\n        \"output_blocks.11.1.transformer_blocks.0.attn1\",\n        # SD 1.5 VAE\n        \"decoder.mid_block.attentions.0\",\n        \"decoder.mid.attn_1\",\n    ],\n    1: [\n        # SD 1.5 U-Net (diffusers)\n        \"down_blocks.1.attentions.0.transformer_blocks.0.attn1\",\n        \"down_blocks.1.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.2.attentions.0.transformer_blocks.0.attn1\",\n        \"up_blocks.2.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.2.attentions.2.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        
\"input_blocks.4.1.transformer_blocks.0.attn1\",\n        \"input_blocks.5.1.transformer_blocks.0.attn1\",\n        \"output_blocks.6.1.transformer_blocks.0.attn1\",\n        \"output_blocks.7.1.transformer_blocks.0.attn1\",\n        \"output_blocks.8.1.transformer_blocks.0.attn1\",\n    ],\n    2: [\n        # SD 1.5 U-Net (diffusers)\n        \"down_blocks.2.attentions.0.transformer_blocks.0.attn1\",\n        \"down_blocks.2.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.1.attentions.0.transformer_blocks.0.attn1\",\n        \"up_blocks.1.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.1.attentions.2.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"input_blocks.7.1.transformer_blocks.0.attn1\",\n        \"input_blocks.8.1.transformer_blocks.0.attn1\",\n        \"output_blocks.3.1.transformer_blocks.0.attn1\",\n        \"output_blocks.4.1.transformer_blocks.0.attn1\",\n        \"output_blocks.5.1.transformer_blocks.0.attn1\",\n    ],\n    3: [\n        # SD 1.5 U-Net (diffusers)\n        \"mid_block.attentions.0.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"middle_block.1.transformer_blocks.0.attn1\",\n    ],\n}\n# XL layers, thanks for GitHub@gel-crabs for the help\nDEPTH_LAYERS_XL = {\n    0: [\n        # SD 1.5 U-Net (diffusers)\n        \"down_blocks.0.attentions.0.transformer_blocks.0.attn1\",\n        \"down_blocks.0.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.0.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.1.transformer_blocks.0.attn1\",\n        \"up_blocks.3.attentions.2.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"input_blocks.4.1.transformer_blocks.0.attn1\",\n        \"input_blocks.5.1.transformer_blocks.0.attn1\",\n        \"output_blocks.3.1.transformer_blocks.0.attn1\",\n        \"output_blocks.4.1.transformer_blocks.0.attn1\",\n        \"output_blocks.5.1.transformer_blocks.0.attn1\",\n        # 
SD 1.5 VAE\n        \"decoder.mid_block.attentions.0\",\n        \"decoder.mid.attn_1\",\n    ],\n    1: [\n        # SD 1.5 U-Net (diffusers)\n        #\"down_blocks.1.attentions.0.transformer_blocks.0.attn1\",\n        #\"down_blocks.1.attentions.1.transformer_blocks.0.attn1\",\n        #\"up_blocks.2.attentions.0.transformer_blocks.0.attn1\",\n        #\"up_blocks.2.attentions.1.transformer_blocks.0.attn1\",\n        #\"up_blocks.2.attentions.2.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"input_blocks.4.1.transformer_blocks.1.attn1\",\n        \"input_blocks.5.1.transformer_blocks.1.attn1\",\n        \"output_blocks.3.1.transformer_blocks.1.attn1\",\n        \"output_blocks.4.1.transformer_blocks.1.attn1\",\n        \"output_blocks.5.1.transformer_blocks.1.attn1\",\n        \"input_blocks.7.1.transformer_blocks.0.attn1\",\n        \"input_blocks.8.1.transformer_blocks.0.attn1\",\n        \"output_blocks.0.1.transformer_blocks.0.attn1\",\n        \"output_blocks.1.1.transformer_blocks.0.attn1\",\n        \"output_blocks.2.1.transformer_blocks.0.attn1\",\n        \"input_blocks.7.1.transformer_blocks.1.attn1\",\n        \"input_blocks.8.1.transformer_blocks.1.attn1\",\n        \"output_blocks.0.1.transformer_blocks.1.attn1\",\n        \"output_blocks.1.1.transformer_blocks.1.attn1\",\n        \"output_blocks.2.1.transformer_blocks.1.attn1\",\n        \"input_blocks.7.1.transformer_blocks.2.attn1\",\n        \"input_blocks.8.1.transformer_blocks.2.attn1\",\n        \"output_blocks.0.1.transformer_blocks.2.attn1\",\n        \"output_blocks.1.1.transformer_blocks.2.attn1\",\n        \"output_blocks.2.1.transformer_blocks.2.attn1\",\n        \"input_blocks.7.1.transformer_blocks.3.attn1\",\n        \"input_blocks.8.1.transformer_blocks.3.attn1\",\n        \"output_blocks.0.1.transformer_blocks.3.attn1\",\n        \"output_blocks.1.1.transformer_blocks.3.attn1\",\n        \"output_blocks.2.1.transformer_blocks.3.attn1\",\n        
\"input_blocks.7.1.transformer_blocks.4.attn1\",\n        \"input_blocks.8.1.transformer_blocks.4.attn1\",\n        \"output_blocks.0.1.transformer_blocks.4.attn1\",\n        \"output_blocks.1.1.transformer_blocks.4.attn1\",\n        \"output_blocks.2.1.transformer_blocks.4.attn1\",\n        \"input_blocks.7.1.transformer_blocks.5.attn1\",\n        \"input_blocks.8.1.transformer_blocks.5.attn1\",\n        \"output_blocks.0.1.transformer_blocks.5.attn1\",\n        \"output_blocks.1.1.transformer_blocks.5.attn1\",\n        \"output_blocks.2.1.transformer_blocks.5.attn1\",\n        \"input_blocks.7.1.transformer_blocks.6.attn1\",\n        \"input_blocks.8.1.transformer_blocks.6.attn1\",\n        \"output_blocks.0.1.transformer_blocks.6.attn1\",\n        \"output_blocks.1.1.transformer_blocks.6.attn1\",\n        \"output_blocks.2.1.transformer_blocks.6.attn1\",\n        \"input_blocks.7.1.transformer_blocks.7.attn1\",\n        \"input_blocks.8.1.transformer_blocks.7.attn1\",\n        \"output_blocks.0.1.transformer_blocks.7.attn1\",\n        \"output_blocks.1.1.transformer_blocks.7.attn1\",\n        \"output_blocks.2.1.transformer_blocks.7.attn1\",\n        \"input_blocks.7.1.transformer_blocks.8.attn1\",\n        \"input_blocks.8.1.transformer_blocks.8.attn1\",\n        \"output_blocks.0.1.transformer_blocks.8.attn1\",\n        \"output_blocks.1.1.transformer_blocks.8.attn1\",\n        \"output_blocks.2.1.transformer_blocks.8.attn1\",\n        \"input_blocks.7.1.transformer_blocks.9.attn1\",\n        \"input_blocks.8.1.transformer_blocks.9.attn1\",\n        \"output_blocks.0.1.transformer_blocks.9.attn1\",\n        \"output_blocks.1.1.transformer_blocks.9.attn1\",\n        \"output_blocks.2.1.transformer_blocks.9.attn1\",\n    ],\n    2: [\n        # SD 1.5 U-Net (diffusers)\n        \"mid_block.attentions.0.transformer_blocks.0.attn1\",\n        # SD 1.5 U-Net (ldm)\n        \"middle_block.1.transformer_blocks.0.attn1\",\n        
\"middle_block.1.transformer_blocks.1.attn1\",\n        \"middle_block.1.transformer_blocks.2.attn1\",\n        \"middle_block.1.transformer_blocks.3.attn1\",\n        \"middle_block.1.transformer_blocks.4.attn1\",\n        \"middle_block.1.transformer_blocks.5.attn1\",\n        \"middle_block.1.transformer_blocks.6.attn1\",\n        \"middle_block.1.transformer_blocks.7.attn1\",\n        \"middle_block.1.transformer_blocks.8.attn1\",\n        \"middle_block.1.transformer_blocks.9.attn1\",\n    ],\n    3 : [] # TODO - separate layers for SD-XL\n}\n\n\nRNG_INSTANCE = random.Random()\n\n@cache\ndef get_divisors(value: int, min_value: int, /, max_options: int = 1) -> list[int]:\n    \"\"\"\n    Returns divisors of value that\n        x * min_value <= value\n    in big -> small order, amount of divisors is limited by max_options\n    \"\"\"\n    max_options = max(1, max_options) # at least 1 option should be returned\n    min_value = min(min_value, value)\n    divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order\n    ns = [value // i for i in divisors[:max_options]]  # has at least 1 element # big -> small order\n    return ns\n\n\ndef random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:\n    \"\"\"\n    Returns a random divisor of value that\n        x * min_value <= value\n    if max_options is 1, the behavior is deterministic\n    \"\"\"\n    ns = get_divisors(value, min_value, max_options=max_options) # get cached divisors\n    idx = RNG_INSTANCE.randint(0, len(ns) - 1)\n\n    return ns[idx]\n\n\ndef set_hypertile_seed(seed: int) -> None:\n    RNG_INSTANCE.seed(seed)\n\n\n@cache\ndef largest_tile_size_available(width: int, height: int) -> int:\n    \"\"\"\n    Calculates the largest tile size available for a given width and height\n    Tile size is always a power of 2\n    \"\"\"\n    gcd = math.gcd(width, height)\n    largest_tile_size_available = 1\n    while gcd % 
(largest_tile_size_available * 2) == 0:\n        largest_tile_size_available *= 2\n    return largest_tile_size_available\n\n\ndef iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:\n    \"\"\"\n    Finds h and w such that h*w = hw and h/w = aspect_ratio\n    We check all possible divisors of hw and return the closest to the aspect ratio\n    \"\"\"\n    divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw\n    pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw\n    ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw\n    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio\n    closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio\n    return closest_pair\n\n\n@cache\ndef find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:\n    \"\"\"\n    Finds h and w such that h*w = hw and h/w = aspect_ratio\n    \"\"\"\n    h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))\n    # find h and w such that h*w = hw and h/w = aspect_ratio\n    if h * w != hw:\n        w_candidate = hw / h\n        # check if w is an integer\n        if not w_candidate.is_integer():\n            h_candidate = hw / w\n            # check if h is an integer\n            if not h_candidate.is_integer():\n                return iterative_closest_divisors(hw, aspect_ratio)\n            else:\n                h = int(h_candidate)\n        else:\n            w = int(w_candidate)\n    return h, w\n\n\ndef self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:\n\n    @wraps(params.forward)\n    def wrapper(*args, **kwargs):\n        if not params.enabled:\n            return params.forward(*args, **kwargs)\n\n        latent_tile_size = max(128, params.tile_size) // 8\n        x = args[0]\n\n        # VAE\n        if x.ndim == 4:\n            b, c, h, w = x.shape\n\n            
nh = random_divisor(h, latent_tile_size, params.swap_size)\n            nw = random_divisor(w, latent_tile_size, params.swap_size)\n\n            if nh * nw > 1:\n                x = rearrange(x, \"b c (nh h) (nw w) -> (b nh nw) c h w\", nh=nh, nw=nw)  # split into nh * nw tiles\n\n            out = params.forward(x, *args[1:], **kwargs)\n\n            if nh * nw > 1:\n                out = rearrange(out, \"(b nh nw) c h w -> b c (nh h) (nw w)\", nh=nh, nw=nw)\n\n        # U-Net\n        else:\n            hw: int = x.size(1)\n            h, w = find_hw_candidates(hw, params.aspect_ratio)\n            assert h * w == hw, f\"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}\"\n\n            factor = 2 ** params.depth if scale_depth else 1\n            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)\n            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)\n\n            if nh * nw > 1:\n                x = rearrange(x, \"b (nh h nw w) c -> (b nh nw) (h w) c\", h=h // nh, w=w // nw, nh=nh, nw=nw)\n\n            out = params.forward(x, *args[1:], **kwargs)\n\n            if nh * nw > 1:\n                out = rearrange(out, \"(b nh nw) hw c -> b nh nw hw c\", nh=nh, nw=nw)\n                out = rearrange(out, \"b nh nw (h w) c -> b (nh h nw w) c\", h=h // nh, w=w // nw)\n\n        return out\n\n    return wrapper\n\n\ndef hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):\n    hypertile_layers = getattr(model, \"__webui_hypertile_layers\", None)\n    if hypertile_layers is None:\n        if not enable:\n            return\n\n        hypertile_layers = {}\n        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS\n\n        for depth in range(4):\n            for layer_name, module in model.named_modules():\n                if any(layer_name.endswith(try_name) for try_name in layers[depth]):\n      
              params = HypertileParams()\n                    module.__webui_hypertile_params = params\n                    params.forward = module.forward\n                    params.depth = depth\n                    params.layer_name = layer_name\n                    module.forward = self_attn_forward(params)\n\n                    hypertile_layers[layer_name] = 1\n\n        model.__webui_hypertile_layers = hypertile_layers\n\n    aspect_ratio = width / height\n    tile_size = min(largest_tile_size_available(width, height), tile_size_max)\n\n    for layer_name, module in model.named_modules():\n        if layer_name in hypertile_layers:\n            params = module.__webui_hypertile_params\n\n            params.tile_size = tile_size\n            params.swap_size = swap_size\n            params.aspect_ratio = aspect_ratio\n            params.enabled = enable and params.depth <= max_depth\n"
  },
  {
    "path": "extensions-builtin/hypertile/scripts/hypertile_script.py",
    "content": "import hypertile\r\nfrom modules import scripts, script_callbacks, shared\r\n\r\n\r\nclass ScriptHypertile(scripts.Script):\r\n    name = \"Hypertile\"\r\n\r\n    def title(self):\r\n        return self.name\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def process(self, p, *args):\r\n        hypertile.set_hypertile_seed(p.all_seeds[0])\r\n\r\n        configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)\r\n\r\n        self.add_infotext(p)\r\n\r\n    def before_hr(self, p, *args):\r\n\r\n        enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet\r\n\r\n        # exclusive hypertile seed for the second pass\r\n        if enable:\r\n            hypertile.set_hypertile_seed(p.all_seeds[0])\r\n\r\n        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable)\r\n\r\n        if enable and not shared.opts.hypertile_enable_unet:\r\n            p.extra_generation_params[\"Hypertile U-Net second pass\"] = True\r\n\r\n            self.add_infotext(p, add_unet_params=True)\r\n\r\n    def add_infotext(self, p, add_unet_params=False):\r\n        def option(name):\r\n            value = getattr(shared.opts, name)\r\n            default_value = shared.opts.get_default(name)\r\n            return None if value == default_value else value\r\n\r\n        if shared.opts.hypertile_enable_unet:\r\n            p.extra_generation_params[\"Hypertile U-Net\"] = True\r\n\r\n        if shared.opts.hypertile_enable_unet or add_unet_params:\r\n            p.extra_generation_params[\"Hypertile U-Net max depth\"] = option('hypertile_max_depth_unet')\r\n            p.extra_generation_params[\"Hypertile U-Net max tile size\"] = option('hypertile_max_tile_unet')\r\n            p.extra_generation_params[\"Hypertile U-Net swap size\"] = option('hypertile_swap_size_unet')\r\n\r\n        if shared.opts.hypertile_enable_vae:\r\n            
p.extra_generation_params[\"Hypertile VAE\"] = True\r\n            p.extra_generation_params[\"Hypertile VAE max depth\"] = option('hypertile_max_depth_vae')\r\n            p.extra_generation_params[\"Hypertile VAE max tile size\"] = option('hypertile_max_tile_vae')\r\n            p.extra_generation_params[\"Hypertile VAE swap size\"] = option('hypertile_swap_size_vae')\r\n\r\n\r\ndef configure_hypertile(width, height, enable_unet=True):\r\n    hypertile.hypertile_hook_model(\r\n        shared.sd_model.first_stage_model,\r\n        width,\r\n        height,\r\n        swap_size=shared.opts.hypertile_swap_size_vae,\r\n        max_depth=shared.opts.hypertile_max_depth_vae,\r\n        tile_size_max=shared.opts.hypertile_max_tile_vae,\r\n        enable=shared.opts.hypertile_enable_vae,\r\n    )\r\n\r\n    hypertile.hypertile_hook_model(\r\n        shared.sd_model.model,\r\n        width,\r\n        height,\r\n        swap_size=shared.opts.hypertile_swap_size_unet,\r\n        max_depth=shared.opts.hypertile_max_depth_unet,\r\n        tile_size_max=shared.opts.hypertile_max_tile_unet,\r\n        enable=enable_unet,\r\n        is_sdxl=shared.sd_model.is_sdxl\r\n    )\r\n\r\n\r\ndef on_ui_settings():\r\n    import gradio as gr\r\n\r\n    options = {\r\n        \"hypertile_explanation\": shared.OptionHTML(\"\"\"\r\n    <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,\r\n    resulting in a reduction in computation time ranging from 1 to 4 times. 
The larger the generated image is, the greater the\r\n    benefit.\r\n    \"\"\"),\r\n\r\n        \"hypertile_enable_unet\": shared.OptionInfo(False, \"Enable Hypertile U-Net\", infotext=\"Hypertile U-Net\").info(\"enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture\"),\r\n        \"hypertile_enable_unet_secondpass\": shared.OptionInfo(False, \"Enable Hypertile U-Net for hires fix second pass\", infotext=\"Hypertile U-Net second pass\").info(\"enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled\"),\r\n        \"hypertile_max_depth_unet\": shared.OptionInfo(3, \"Hypertile U-Net max depth\", gr.Slider, {\"minimum\": 0, \"maximum\": 3, \"step\": 1}, infotext=\"Hypertile U-Net max depth\").info(\"larger = more neural network layers affected; minor effect on performance\"),\r\n        \"hypertile_max_tile_unet\": shared.OptionInfo(256, \"Hypertile U-Net max tile size\", gr.Slider, {\"minimum\": 0, \"maximum\": 512, \"step\": 16}, infotext=\"Hypertile U-Net max tile size\").info(\"larger = worse performance\"),\r\n        \"hypertile_swap_size_unet\": shared.OptionInfo(3, \"Hypertile U-Net swap size\", gr.Slider, {\"minimum\": 0, \"maximum\": 64, \"step\": 1}, infotext=\"Hypertile U-Net swap size\"),\r\n        \"hypertile_enable_vae\": shared.OptionInfo(False, \"Enable Hypertile VAE\", infotext=\"Hypertile VAE\").info(\"minimal change in the generated picture\"),\r\n        \"hypertile_max_depth_vae\": shared.OptionInfo(3, \"Hypertile VAE max depth\", gr.Slider, {\"minimum\": 0, \"maximum\": 3, \"step\": 1}, infotext=\"Hypertile VAE max depth\"),\r\n        \"hypertile_max_tile_vae\": shared.OptionInfo(128, \"Hypertile VAE max tile size\", gr.Slider, {\"minimum\": 0, \"maximum\": 512, \"step\": 16}, infotext=\"Hypertile VAE max tile size\"),\r\n        \"hypertile_swap_size_vae\": shared.OptionInfo(3, \"Hypertile VAE swap size \", gr.Slider, 
{\"minimum\": 0, \"maximum\": 64, \"step\": 1}, infotext=\"Hypertile VAE swap size\"),\r\n    }\r\n\r\n    for name, opt in options.items():\r\n        opt.section = ('hypertile', \"Hypertile\")\r\n        shared.opts.add_option(name, opt)\r\n\r\n\r\ndef add_axis_options():\r\n    xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == \"xyz_grid.py\"][0].module\r\n    xyz_grid.axis_options.extend([\r\n        xyz_grid.AxisOption(\"[Hypertile] Unet First pass Enabled\", str, xyz_grid.apply_override('hypertile_enable_unet', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),\r\n        xyz_grid.AxisOption(\"[Hypertile] Unet Second pass Enabled\", str, xyz_grid.apply_override('hypertile_enable_unet_secondpass', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),\r\n        xyz_grid.AxisOption(\"[Hypertile] Unet Max Depth\", int, xyz_grid.apply_override(\"hypertile_max_depth_unet\"), confirm=xyz_grid.confirm_range(0, 3, '[Hypertile] Unet Max Depth'), choices=lambda: [str(x) for x in range(4)]),\r\n        xyz_grid.AxisOption(\"[Hypertile] Unet Max Tile Size\", int, xyz_grid.apply_override(\"hypertile_max_tile_unet\"), confirm=xyz_grid.confirm_range(0, 512, '[Hypertile] Unet Max Tile Size')),\r\n        xyz_grid.AxisOption(\"[Hypertile] Unet Swap Size\", int, xyz_grid.apply_override(\"hypertile_swap_size_unet\"), confirm=xyz_grid.confirm_range(0, 64, '[Hypertile] Unet Swap Size')),\r\n        xyz_grid.AxisOption(\"[Hypertile] VAE Enabled\", str, xyz_grid.apply_override('hypertile_enable_vae', boolean=True), choices=xyz_grid.boolean_choice(reverse=True)),\r\n        xyz_grid.AxisOption(\"[Hypertile] VAE Max Depth\", int, xyz_grid.apply_override(\"hypertile_max_depth_vae\"), confirm=xyz_grid.confirm_range(0, 3, '[Hypertile] VAE Max Depth'), choices=lambda: [str(x) for x in range(4)]),\r\n        xyz_grid.AxisOption(\"[Hypertile] VAE Max Tile Size\", int, xyz_grid.apply_override(\"hypertile_max_tile_vae\"), 
confirm=xyz_grid.confirm_range(0, 512, '[Hypertile] VAE Max Tile Size')),\r\n        xyz_grid.AxisOption(\"[Hypertile] VAE Swap Size\", int, xyz_grid.apply_override(\"hypertile_swap_size_vae\"), confirm=xyz_grid.confirm_range(0, 64, '[Hypertile] VAE Swap Size')),\r\n    ])\r\n\r\n\r\nscript_callbacks.on_ui_settings(on_ui_settings)\r\nscript_callbacks.on_before_ui(add_axis_options)\r\n"
  },
  {
    "path": "extensions-builtin/mobile/javascript/mobile.js",
    "content": "var isSetupForMobile = false;\n\nfunction isMobile() {\n    for (var tab of [\"txt2img\", \"img2img\"]) {\n        var imageTab = gradioApp().getElementById(tab + '_results');\n        if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) {\n            return true;\n        }\n    }\n\n    return false;\n}\n\nfunction reportWindowSize() {\n    if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout\n\n    var currentlyMobile = isMobile();\n    if (currentlyMobile == isSetupForMobile) return;\n    isSetupForMobile = currentlyMobile;\n\n    for (var tab of [\"txt2img\", \"img2img\"]) {\n        var button = gradioApp().getElementById(tab + '_generate_box');\n        var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');\n        target.insertBefore(button, target.firstElementChild);\n\n        gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile);\n    }\n}\n\nwindow.addEventListener(\"resize\", reportWindowSize);\n\nonUiLoaded(function() {\n    reportWindowSize();\n});\n"
  },
  {
    "path": "extensions-builtin/postprocessing-for-training/scripts/postprocessing_autosized_crop.py",
    "content": "from PIL import Image\r\n\r\nfrom modules import scripts_postprocessing, ui_components\r\nimport gradio as gr\r\n\r\n\r\ndef center_crop(image: Image, w: int, h: int):\r\n    iw, ih = image.size\r\n    if ih / h < iw / w:\r\n        sw = w * ih / h\r\n        box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih\r\n    else:\r\n        sh = h * iw / w\r\n        box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2\r\n    return image.resize((w, h), Image.Resampling.LANCZOS, box)\r\n\r\n\r\ndef multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold):\r\n    iw, ih = image.size\r\n    err = lambda w, h: 1 - (lambda x: x if x < 1 else 1 / x)(iw / ih / (w / h))\r\n    wh = max(((w, h) for w in range(mindim, maxdim + 1, 64) for h in range(mindim, maxdim + 1, 64)\r\n              if minarea <= w * h <= maxarea and err(w, h) <= threshold),\r\n             key=lambda wh: (wh[0] * wh[1], -err(*wh))[::1 if objective == 'Maximize area' else -1],\r\n             default=None\r\n             )\r\n    return wh and center_crop(image, *wh)\r\n\r\n\r\nclass ScriptPostprocessingAutosizedCrop(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Auto-sized crop\"\r\n    order = 4020\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"Auto-sized crop\") as enable:\r\n            gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')\r\n            with gr.Row():\r\n                mindim = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Dimension lower bound\", value=384, elem_id=\"postprocess_multicrop_mindim\")\r\n                maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Dimension upper bound\", value=768, elem_id=\"postprocess_multicrop_maxdim\")\r\n            with gr.Row():\r\n                minarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label=\"Area lower bound\", value=64 * 64, elem_id=\"postprocess_multicrop_minarea\")\r\n      
          maxarea = gr.Slider(minimum=64 * 64, maximum=2048 * 2048, step=1, label=\"Area upper bound\", value=640 * 640, elem_id=\"postprocess_multicrop_maxarea\")\r\n            with gr.Row():\r\n                objective = gr.Radio([\"Maximize area\", \"Minimize error\"], value=\"Maximize area\", label=\"Resizing objective\", elem_id=\"postprocess_multicrop_objective\")\r\n                threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label=\"Error threshold\", value=0.1, elem_id=\"postprocess_multicrop_threshold\")\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"mindim\": mindim,\r\n            \"maxdim\": maxdim,\r\n            \"minarea\": minarea,\r\n            \"maxarea\": maxarea,\r\n            \"objective\": objective,\r\n            \"threshold\": threshold,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, mindim, maxdim, minarea, maxarea, objective, threshold):\r\n        if not enable:\r\n            return\r\n\r\n        cropped = multicrop_pic(pp.image, mindim, maxdim, minarea, maxarea, objective, threshold)\r\n        if cropped is not None:\r\n            pp.image = cropped\r\n        else:\r\n            print(f\"skipped {pp.image.width}x{pp.image.height} image (can't find suitable size within error threshold)\")\r\n"
  },
  {
    "path": "extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py",
    "content": "from modules import scripts_postprocessing, ui_components, deepbooru, shared\r\nimport gradio as gr\r\n\r\n\r\nclass ScriptPostprocessingCeption(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Caption\"\r\n    order = 4040\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"Caption\") as enable:\r\n            option = gr.CheckboxGroup(value=[\"Deepbooru\"], choices=[\"Deepbooru\", \"BLIP\"], show_label=False)\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"option\": option,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):\r\n        if not enable:\r\n            return\r\n\r\n        captions = [pp.caption]\r\n\r\n        if \"Deepbooru\" in option:\r\n            captions.append(deepbooru.model.tag(pp.image))\r\n\r\n        if \"BLIP\" in option:\r\n            captions.append(shared.interrogator.interrogate(pp.image.convert(\"RGB\")))\r\n\r\n        pp.caption = \", \".join([x for x in captions if x])\r\n"
  },
  {
    "path": "extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py",
    "content": "from PIL import ImageOps, Image\r\n\r\nfrom modules import scripts_postprocessing, ui_components\r\nimport gradio as gr\r\n\r\n\r\nclass ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Create flipped copies\"\r\n    order = 4030\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"Create flipped copies\") as enable:\r\n            with gr.Row():\r\n                option = gr.CheckboxGroup(value=[\"Horizontal\"], choices=[\"Horizontal\", \"Vertical\", \"Both\"], show_label=False)\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"option\": option,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option):\r\n        if not enable:\r\n            return\r\n\r\n        if \"Horizontal\" in option:\r\n            pp.extra_images.append(ImageOps.mirror(pp.image))\r\n\r\n        if \"Vertical\" in option:\r\n            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM))\r\n\r\n        if \"Both\" in option:\r\n            pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT))\r\n"
  },
  {
    "path": "extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py",
    "content": "\r\nfrom modules import scripts_postprocessing, ui_components, errors\r\nimport gradio as gr\r\n\r\nfrom modules.textual_inversion import autocrop\r\n\r\n\r\nclass ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Auto focal point crop\"\r\n    order = 4010\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"Auto focal point crop\") as enable:\r\n            face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id=\"postprocess_focal_crop_face_weight\")\r\n            entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id=\"postprocess_focal_crop_entropy_weight\")\r\n            edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id=\"postprocess_focal_crop_edges_weight\")\r\n            debug = gr.Checkbox(label='Create debug image', elem_id=\"train_process_focal_crop_debug\")\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"face_weight\": face_weight,\r\n            \"entropy_weight\": entropy_weight,\r\n            \"edges_weight\": edges_weight,\r\n            \"debug\": debug,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug):\r\n        if not enable:\r\n            return\r\n\r\n        if not pp.shared.target_width or not pp.shared.target_height:\r\n            return\r\n\r\n        dnn_model_path = None\r\n        try:\r\n            dnn_model_path = autocrop.download_and_cache_models()\r\n        except Exception:\r\n            errors.report(\"Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.\", exc_info=True)\r\n\r\n        autocrop_settings = autocrop.Settings(\r\n            crop_width=pp.shared.target_width,\r\n            crop_height=pp.shared.target_height,\r\n            face_points_weight=face_weight,\r\n            entropy_points_weight=entropy_weight,\r\n            corner_points_weight=edges_weight,\r\n            annotate_image=debug,\r\n            dnn_model_path=dnn_model_path,\r\n        )\r\n\r\n        result, *others = autocrop.crop_image(pp.image, autocrop_settings)\r\n\r\n        pp.image = result\r\n        pp.extra_images = [pp.create_copy(x, nametags=[\"focal-crop-debug\"], disable_processing=True) for x in others]\r\n\r\n"
  },
  {
    "path": "extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py",
    "content": "import math\r\n\r\nfrom modules import scripts_postprocessing, ui_components\r\nimport gradio as gr\r\n\r\n\r\ndef split_pic(image, inverse_xy, width, height, overlap_ratio):\r\n    if inverse_xy:\r\n        from_w, from_h = image.height, image.width\r\n        to_w, to_h = height, width\r\n    else:\r\n        from_w, from_h = image.width, image.height\r\n        to_w, to_h = width, height\r\n    h = from_h * to_w // from_w\r\n    if inverse_xy:\r\n        image = image.resize((h, to_w))\r\n    else:\r\n        image = image.resize((to_w, h))\r\n\r\n    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))\r\n    y_step = (h - to_h) / (split_count - 1)\r\n    for i in range(split_count):\r\n        y = int(y_step * i)\r\n        if inverse_xy:\r\n            splitted = image.crop((y, 0, y + to_h, to_w))\r\n        else:\r\n            splitted = image.crop((0, y, to_w, y + to_h))\r\n        yield splitted\r\n\r\n\r\nclass ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Split oversized images\"\r\n    order = 4000\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"Split oversized images\") as enable:\r\n            with gr.Row():\r\n                split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id=\"postprocess_split_threshold\")\r\n                overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id=\"postprocess_overlap_ratio\")\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"split_threshold\": split_threshold,\r\n            \"overlap_ratio\": overlap_ratio,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio):\r\n        if not enable:\r\n            return\r\n\r\n        width = pp.shared.target_width\r\n        height = 
pp.shared.target_height\r\n\r\n        if not width or not height:\r\n            return\r\n\r\n        if pp.image.height > pp.image.width:\r\n            ratio = (pp.image.width * height) / (pp.image.height * width)\r\n            inverse_xy = False\r\n        else:\r\n            ratio = (pp.image.height * width) / (pp.image.width * height)\r\n            inverse_xy = True\r\n\r\n        if ratio >= 1.0 or ratio > split_threshold:\r\n            return\r\n\r\n        result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio)\r\n\r\n        pp.image = result\r\n        pp.extra_images = [pp.create_copy(x) for x in others]\r\n\r\n"
  },
  {
    "path": "extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js",
    "content": "// Stable Diffusion WebUI - Bracket checker\n// By Hingashi no Florin/Bwin4L & @akx\n// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.\n// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.\n\nfunction checkBrackets(textArea, counterElt) {\n    var counts = {};\n    (textArea.value.match(/[(){}[\\]]/g) || []).forEach(bracket => {\n        counts[bracket] = (counts[bracket] || 0) + 1;\n    });\n    var errors = [];\n\n    function checkPair(open, close, kind) {\n        if (counts[open] !== counts[close]) {\n            errors.push(\n                `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`\n            );\n        }\n    }\n\n    checkPair('(', ')', 'round brackets');\n    checkPair('[', ']', 'square brackets');\n    checkPair('{', '}', 'curly brackets');\n    counterElt.title = errors.join('\\n');\n    counterElt.classList.toggle('error', errors.length !== 0);\n}\n\nfunction setupBracketChecking(id_prompt, id_counter) {\n    var textarea = gradioApp().querySelector(\"#\" + id_prompt + \" > label > textarea\");\n    var counter = gradioApp().getElementById(id_counter);\n\n    if (textarea && counter) {\n        textarea.addEventListener(\"input\", () => checkBrackets(textarea, counter));\n    }\n}\n\nonUiLoaded(function() {\n    setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');\n    setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');\n    setupBracketChecking('img2img_prompt', 'img2img_token_counter');\n    setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');\n});\n"
  },
  {
    "path": "extensions-builtin/soft-inpainting/scripts/soft_inpainting.py",
    "content": "import numpy as np\nimport gradio as gr\nimport math\nfrom modules.ui_components import InputAccordion\nimport modules.scripts as scripts\nfrom modules.torch_utils import float64\n\n\nclass SoftInpaintingSettings:\n    def __init__(self,\n                 mask_blend_power,\n                 mask_blend_scale,\n                 inpaint_detail_preservation,\n                 composite_mask_influence,\n                 composite_difference_threshold,\n                 composite_difference_contrast):\n        self.mask_blend_power = mask_blend_power\n        self.mask_blend_scale = mask_blend_scale\n        self.inpaint_detail_preservation = inpaint_detail_preservation\n        self.composite_mask_influence = composite_mask_influence\n        self.composite_difference_threshold = composite_difference_threshold\n        self.composite_difference_contrast = composite_difference_contrast\n\n    def add_generation_params(self, dest):\n        dest[enabled_gen_param_label] = True\n        dest[gen_param_labels.mask_blend_power] = self.mask_blend_power\n        dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale\n        dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation\n        dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence\n        dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold\n        dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast\n\n\n# ------------------- Methods -------------------\n\ndef processing_uses_inpainting(p):\n    # TODO: Figure out a better way to determine if inpainting is being used by p\n    if getattr(p, \"image_mask\", None) is not None:\n        return True\n\n    if getattr(p, \"mask\", None) is not None:\n        return True\n\n    if getattr(p, \"nmask\", None) is not None:\n        return True\n\n    return False\n\n\ndef latent_blend(settings, a, b, t):\n    
\"\"\"\n    Interpolates two latent image representations according to the parameter t,\n    where the interpolated vectors' magnitudes are also interpolated separately.\n    The \"detail_preservation\" factor biases the magnitude interpolation towards\n    the larger of the two magnitudes.\n    \"\"\"\n    import torch\n\n    # NOTE: We use inplace operations wherever possible.\n\n    if len(t.shape) == 3:\n        # [4][w][h] to [1][4][w][h]\n        t2 = t.unsqueeze(0)\n        # [4][w][h] to [1][1][w][h] - the [4] seem redundant.\n        t3 = t[0].unsqueeze(0).unsqueeze(0)\n    else:\n        t2 = t\n        t3 = t[:, 0][:, None]\n\n    one_minus_t2 = 1 - t2\n    one_minus_t3 = 1 - t3\n\n    # Linearly interpolate the image vectors.\n    a_scaled = a * one_minus_t2\n    b_scaled = b * t2\n    image_interp = a_scaled\n    image_interp.add_(b_scaled)\n    result_type = image_interp.dtype\n    del a_scaled, b_scaled, t2, one_minus_t2\n\n    # Calculate the magnitude of the interpolated vectors. 
(We will remove this magnitude.)\n    # 64-bit operations are used here to allow large exponents.\n    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(float64(image_interp)).add_(0.00001)\n\n    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).\n    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(float64(a)).pow_(settings.inpaint_detail_preservation) * one_minus_t3\n    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(float64(b)).pow_(settings.inpaint_detail_preservation) * t3\n    desired_magnitude = a_magnitude\n    desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)\n    del a_magnitude, b_magnitude, t3, one_minus_t3\n\n    # Change the linearly interpolated image vectors' magnitudes to the value we want.\n    # This is the last 64-bit operation.\n    image_interp_scaling_factor = desired_magnitude\n    image_interp_scaling_factor.div_(current_magnitude)\n    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)\n    image_interp_scaled = image_interp\n    image_interp_scaled.mul_(image_interp_scaling_factor)\n    del current_magnitude\n    del desired_magnitude\n    del image_interp\n    del image_interp_scaling_factor\n    del result_type\n\n    return image_interp_scaled\n\n\ndef get_modified_nmask(settings, nmask, sigma):\n    \"\"\"\n    Converts a negative mask representing the transparency of the original latent vectors being overlaid\n    to a mask that is scaled according to the denoising strength for this step.\n\n    Where:\n        0 = fully opaque, infinite density, fully masked\n        1 = fully transparent, zero density, fully unmasked\n\n    We bring this transparency to a power, as this allows one to simulate N number of blending operations\n    where N can be any positive real value. 
Using this one can control the balance of influence between\n    the denoiser and the original latents according to the sigma value.\n\n    NOTE: \"mask\" is not used\n    \"\"\"\n    import torch\n    return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)\n\n\ndef apply_adaptive_masks(\n        settings: SoftInpaintingSettings,\n        nmask,\n        latent_orig,\n        latent_processed,\n        overlay_images,\n        width, height,\n        paste_to):\n    import torch\n    import modules.processing as proc\n    import modules.images as images\n    from PIL import Image, ImageOps, ImageFilter\n\n    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.\n    if len(nmask.shape) == 3:\n        latent_mask = nmask[0].float()\n    else:\n        latent_mask = nmask[:, 0].float()\n    # convert the original mask into a form we use to scale distances for thresholding\n    mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))\n    mask_scalar = (0.5 * (1 - settings.composite_mask_influence)\n                   + mask_scalar * settings.composite_mask_influence)\n    mask_scalar = mask_scalar / (1.00001 - mask_scalar)\n    mask_scalar = mask_scalar.cpu().numpy()\n\n    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)\n\n    kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2)\n\n    masks_for_overlay = []\n\n    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):\n        converted_mask = distance_map.float().cpu().numpy()\n        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,\n                                                   percentile_min=0.9, percentile_max=1, min_width=1)\n        converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center,\n                                                   
percentile_min=0.25, percentile_max=0.75, min_width=1)\n\n        # The distance at which opacity of original decreases to 50%\n        if len(mask_scalar.shape) == 3:\n            if mask_scalar.shape[0] > i:\n                half_weighted_distance = settings.composite_difference_threshold * mask_scalar[i]\n            else:\n                half_weighted_distance = settings.composite_difference_threshold * mask_scalar[0]\n        else:\n            half_weighted_distance = settings.composite_difference_threshold * mask_scalar\n\n        converted_mask = converted_mask / half_weighted_distance\n\n        converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)\n        converted_mask = smootherstep(converted_mask)\n        converted_mask = 1 - converted_mask\n        converted_mask = 255. * converted_mask\n        converted_mask = converted_mask.astype(np.uint8)\n        converted_mask = Image.fromarray(converted_mask)\n        converted_mask = images.resize_image(2, converted_mask, width, height)\n        converted_mask = proc.create_binary_mask(converted_mask, round=False)\n\n        # Remove aliasing artifacts using a gaussian blur.\n        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))\n\n        # Expand the mask to fit the whole image if needed.\n        if paste_to is not None:\n            converted_mask = proc.uncrop(converted_mask,\n                                         (overlay_image.width, overlay_image.height),\n                                         paste_to)\n\n        masks_for_overlay.append(converted_mask)\n\n        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))\n        image_masked.paste(overlay_image.convert(\"RGBA\").convert(\"RGBa\"),\n                           mask=ImageOps.invert(converted_mask.convert('L')))\n\n        overlay_images[i] = image_masked.convert('RGBA')\n\n    return masks_for_overlay\n\n\ndef apply_masks(\n        settings,\n       
 nmask,\n        overlay_images,\n        width, height,\n        paste_to):\n    import torch\n    import modules.processing as proc\n    import modules.images as images\n    from PIL import Image, ImageOps, ImageFilter\n\n    converted_mask = nmask[0].float()\n    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)\n    converted_mask = 255. * converted_mask\n    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)\n    converted_mask = Image.fromarray(converted_mask)\n    converted_mask = images.resize_image(2, converted_mask, width, height)\n    converted_mask = proc.create_binary_mask(converted_mask, round=False)\n\n    # Remove aliasing artifacts using a gaussian blur.\n    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))\n\n    masks_for_overlay = []\n\n    for i, overlay_image in enumerate(overlay_images):\n        # Expand the mask to fit the whole image if needed.\n        # When paste_to is set the overlay images keep their full size, so uncrop to each\n        # overlay image's size (as apply_adaptive_masks does) rather than to (width, height).\n        if paste_to is not None:\n            mask_for_image = proc.uncrop(converted_mask,\n                                         (overlay_image.width, overlay_image.height),\n                                         paste_to)\n        else:\n            mask_for_image = converted_mask\n\n        # The list starts empty, so append instead of index-assigning.\n        masks_for_overlay.append(mask_for_image)\n\n        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))\n        image_masked.paste(overlay_image.convert(\"RGBA\").convert(\"RGBa\"),\n                           mask=ImageOps.invert(mask_for_image.convert('L')))\n\n        overlay_images[i] = image_masked.convert('RGBA')\n\n    return masks_for_overlay\n\n\ndef weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0):\n    \"\"\"\n    Generalized convolution filter capable of applying\n    weighted mean, median, maximum, and minimum filters\n    parametrically using an arbitrary kernel.\n\n    Args:\n        img (nparray):\n            The image, a 2-D array of floats, to which the filter is being applied.\n        kernel (nparray):\n            The 
kernel, a 2-D array of floats.\n        kernel_center (nparray):\n            The kernel center coordinate, a 1-D array with two elements.\n        percentile_min (float):\n            The lower bound of the histogram window used by the filter,\n            from 0 to 1.\n        percentile_max (float):\n            The upper bound of the histogram window used by the filter,\n            from 0 to 1.\n        min_width (float):\n            The minimum size of the histogram window bounds, in weight units.\n            Must be greater than 0.\n\n    Returns:\n        (nparray): A filtered copy of the input image \"img\", a 2-D array of floats.\n    \"\"\"\n\n    # Converts an index tuple into a vector.\n    def vec(x):\n        return np.array(x)\n\n    kernel_min = -kernel_center\n    kernel_max = vec(kernel.shape) - kernel_center\n\n    def weighted_histogram_filter_single(idx):\n        idx = vec(idx)\n        min_index = np.maximum(0, idx + kernel_min)\n        max_index = np.minimum(vec(img.shape), idx + kernel_max)\n        window_shape = max_index - min_index\n\n        class WeightedElement:\n            \"\"\"\n            An element of the histogram, its weight\n            and bounds.\n            \"\"\"\n\n            def __init__(self, value, weight):\n                self.value: float = value\n                self.weight: float = weight\n                self.window_min: float = 0.0\n                self.window_max: float = 1.0\n\n        # Collect the values in the image as WeightedElements,\n        # weighted by their corresponding kernel values.\n        values = []\n        for window_tup in np.ndindex(tuple(window_shape)):\n            window_index = vec(window_tup)\n            image_index = window_index + min_index\n            centered_kernel_index = image_index - idx\n            kernel_index = centered_kernel_index + kernel_center\n            element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)])\n            
values.append(element)\n\n        def sort_key(x: WeightedElement):\n            return x.value\n\n        values.sort(key=sort_key)\n\n        # Calculate the height of the stack (sum)\n        # and each sample's range they occupy in the stack\n        sum = 0\n        for i in range(len(values)):\n            values[i].window_min = sum\n            sum += values[i].weight\n            values[i].window_max = sum\n\n        # Calculate what range of this stack (\"window\")\n        # we want to get the weighted average across.\n        window_min = sum * percentile_min\n        window_max = sum * percentile_max\n        window_width = window_max - window_min\n\n        # Ensure the window is within the stack and at least a certain size.\n        if window_width < min_width:\n            window_center = (window_min + window_max) / 2\n            window_min = window_center - min_width / 2\n            window_max = window_center + min_width / 2\n\n            if window_max > sum:\n                window_max = sum\n                window_min = sum - min_width\n\n            if window_min < 0:\n                window_min = 0\n                window_max = min_width\n\n        value = 0\n        value_weight = 0\n\n        # Get the weighted average of all the samples\n        # that overlap with the window, weighted\n        # by the size of their overlap.\n        for i in range(len(values)):\n            if window_min >= values[i].window_max:\n                continue\n            if window_max <= values[i].window_min:\n                break\n\n            s = max(window_min, values[i].window_min)\n            e = min(window_max, values[i].window_max)\n            w = e - s\n\n            value += values[i].value * w\n            value_weight += w\n\n        return value / value_weight if value_weight != 0 else 0\n\n    img_out = img.copy()\n\n    # Apply the kernel operation over each pixel.\n    for index in np.ndindex(img.shape):\n        img_out[index] = 
weighted_histogram_filter_single(index)\n\n    return img_out\n\n\ndef smoothstep(x):\n    \"\"\"\n    The smoothstep function, input should be clamped to 0-1 range.\n    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.\n    \"\"\"\n    return x * x * (3 - 2 * x)\n\n\ndef smootherstep(x):\n    \"\"\"\n    The smootherstep function, input should be clamped to 0-1 range.\n    Turns a diagonal line (f(x) = x) into a sigmoid-like curve.\n    \"\"\"\n    return x * x * x * (x * (6 * x - 15) + 10)\n\n\ndef get_gaussian_kernel(stddev_radius=1.0, max_radius=2):\n    \"\"\"\n    Creates a Gaussian kernel with thresholded edges.\n\n    Args:\n        stddev_radius (float):\n            Standard deviation of the gaussian kernel, in pixels.\n        max_radius (int):\n            The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2.\n            The kernel is thresholded so that any values one pixel beyond this radius\n            is weighted at 0.\n\n    Returns:\n        (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2))\n    \"\"\"\n\n    # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean.\n    def gaussian(sqr_mag):\n        return math.exp(-sqr_mag / (stddev_radius * stddev_radius))\n\n    # Helper function for converting a tuple to an array.\n    def vec(x):\n        return np.array(x)\n\n    \"\"\"\n    Since a gaussian is unbounded, we need to limit ourselves\n    to a finite range.\n    We taper the ends off at the end of that range so they equal zero\n    while preserving the maximum value of 1 at the mean.\n    \"\"\"\n    zero_radius = max_radius + 1.0\n    gauss_zero = gaussian(zero_radius * zero_radius)\n    gauss_kernel_scale = 1 / (1 - gauss_zero)\n\n    def gaussian_kernel_func(coordinate):\n        x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0\n        x = gaussian(x)\n        x -= gauss_zero\n        x *= gauss_kernel_scale\n        x = max(0.0, 
x)\n        return x\n\n    size = max_radius * 2 + 1\n    kernel_center = max_radius\n    kernel = np.zeros((size, size))\n\n    for index in np.ndindex(kernel.shape):\n        kernel[index] = gaussian_kernel_func(vec(index) - kernel_center)\n\n    return kernel, kernel_center\n\n\n# ------------------- Constants -------------------\n\n\ndefault = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)\n\nenabled_ui_label = \"Soft inpainting\"\nenabled_gen_param_label = \"Soft inpainting enabled\"\nenabled_el_id = \"soft_inpainting_enabled\"\n\nui_labels = SoftInpaintingSettings(\n    \"Schedule bias\",\n    \"Preservation strength\",\n    \"Transition contrast boost\",\n    \"Mask influence\",\n    \"Difference threshold\",\n    \"Difference contrast\")\n\nui_info = SoftInpaintingSettings(\n    \"Shifts when preservation of original content occurs during denoising.\",\n    \"How strongly partially masked content should be preserved.\",\n    \"Amplifies the contrast that may be lost in partially masked regions.\",\n    \"How strongly the original mask should bias the difference threshold.\",\n    \"How much an image region can change before the original pixels are not blended in anymore.\",\n    \"How sharp the transition should be between blended and not blended.\")\n\ngen_param_labels = SoftInpaintingSettings(\n    \"Soft inpainting schedule bias\",\n    \"Soft inpainting preservation strength\",\n    \"Soft inpainting transition contrast boost\",\n    \"Soft inpainting mask influence\",\n    \"Soft inpainting difference threshold\",\n    \"Soft inpainting difference contrast\")\n\nel_ids = SoftInpaintingSettings(\n    \"mask_blend_power\",\n    \"mask_blend_scale\",\n    \"inpaint_detail_preservation\",\n    \"composite_mask_influence\",\n    \"composite_difference_threshold\",\n    \"composite_difference_contrast\")\n\n\n# ------------------- Script -------------------\n\n\nclass Script(scripts.Script):\n    def __init__(self):\n        self.section = \"inpaint\"\n       
 self.masks_for_overlay = None\n        self.overlay_images = None\n\n    def title(self):\n        return \"Soft Inpainting\"\n\n    def show(self, is_img2img):\n        return scripts.AlwaysVisible if is_img2img else False\n\n    def ui(self, is_img2img):\n        if not is_img2img:\n            return\n\n        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:\n            with gr.Group():\n                gr.Markdown(\n                    \"\"\"\n                    Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity.\n                    **High _Mask blur_** values are recommended!\n                    \"\"\")\n\n                power = \\\n                    gr.Slider(label=ui_labels.mask_blend_power,\n                              info=ui_info.mask_blend_power,\n                              minimum=0,\n                              maximum=8,\n                              step=0.1,\n                              value=default.mask_blend_power,\n                              elem_id=el_ids.mask_blend_power)\n                scale = \\\n                    gr.Slider(label=ui_labels.mask_blend_scale,\n                              info=ui_info.mask_blend_scale,\n                              minimum=0,\n                              maximum=8,\n                              step=0.05,\n                              value=default.mask_blend_scale,\n                              elem_id=el_ids.mask_blend_scale)\n                detail = \\\n                    gr.Slider(label=ui_labels.inpaint_detail_preservation,\n                              info=ui_info.inpaint_detail_preservation,\n                              minimum=1,\n                              maximum=32,\n                              step=0.5,\n                              value=default.inpaint_detail_preservation,\n                              
elem_id=el_ids.inpaint_detail_preservation)\n\n                gr.Markdown(\n                    \"\"\"\n                    ### Pixel Composite Settings\n                    \"\"\")\n\n                mask_inf = \\\n                    gr.Slider(label=ui_labels.composite_mask_influence,\n                              info=ui_info.composite_mask_influence,\n                              minimum=0,\n                              maximum=1,\n                              step=0.05,\n                              value=default.composite_mask_influence,\n                              elem_id=el_ids.composite_mask_influence)\n\n                dif_thresh = \\\n                    gr.Slider(label=ui_labels.composite_difference_threshold,\n                              info=ui_info.composite_difference_threshold,\n                              minimum=0,\n                              maximum=8,\n                              step=0.25,\n                              value=default.composite_difference_threshold,\n                              elem_id=el_ids.composite_difference_threshold)\n\n                dif_contr = \\\n                    gr.Slider(label=ui_labels.composite_difference_contrast,\n                              info=ui_info.composite_difference_contrast,\n                              minimum=0,\n                              maximum=8,\n                              step=0.25,\n                              value=default.composite_difference_contrast,\n                              elem_id=el_ids.composite_difference_contrast)\n\n                with gr.Accordion(\"Help\", open=False):\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.mask_blend_power}\n\n                        The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas).\n                        This ensures that the influence of the denoiser and original content 
preservation is roughly balanced at each step.\n                        This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation.\n\n                        - **Below 1**: Stronger preservation near the end (with low sigma)\n                        - **1**: Balanced (proportional to sigma)\n                        - **Above 1**: Stronger preservation in the beginning (with high sigma)\n                        \"\"\")\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.mask_blend_scale}\n\n                        Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content.\n                        This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength.\n\n                        - **Low values**: Favors generated content.\n                        - **High values**: Favors original content.\n                        \"\"\")\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.inpaint_detail_preservation}\n\n                        This parameter controls how the original latent vectors and denoised latent vectors are interpolated.\n                        With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors.\n                        This can prevent the loss of contrast that occurs with linear interpolation.\n\n                        - **Low values**: Softer blending, details may fade.\n                        - **High values**: Stronger contrast, may over-saturate colors.\n                        \"\"\")\n\n                    gr.Markdown(\n                        \"\"\"\n                        ## Pixel Composite Settings\n\n                        Masks are generated based on how much a part of 
the image changed after denoising.\n                        These masks are used to blend the original and final images together.\n                        If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.\n                        \"\"\")\n\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.composite_mask_influence}\n\n                        This parameter controls how much the mask should bias this sensitivity to difference.\n\n                        - **0**: Ignore the mask, only consider differences in image content.\n                        - **1**: Follow the mask closely despite image content changes.\n                        \"\"\")\n\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.composite_difference_threshold}\n\n                        This value represents the difference at which the original pixels will have less than 50% opacity.\n\n                        - **Low values**: Two images patches must be almost the same in order to retain original pixels.\n                        - **High values**: Two images patches can be very different and still retain original pixels.\n                        \"\"\")\n\n                    gr.Markdown(\n                        f\"\"\"\n                        ### {ui_labels.composite_difference_contrast}\n\n                        This value represents the contrast between the opacity of the original and inpainted content.\n\n                        - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting.\n                        - **High values**: Ghosting will be less common, but transitions may be very sudden.\n                        \"\"\")\n\n        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),\n                                (power, 
gen_param_labels.mask_blend_power),\n                                (scale, gen_param_labels.mask_blend_scale),\n                                (detail, gen_param_labels.inpaint_detail_preservation),\n                                (mask_inf, gen_param_labels.composite_mask_influence),\n                                (dif_thresh, gen_param_labels.composite_difference_threshold),\n                                (dif_contr, gen_param_labels.composite_difference_contrast)]\n\n        self.paste_field_names = []\n        for _, field_name in self.infotext_fields:\n            self.paste_field_names.append(field_name)\n\n        return [soft_inpainting_enabled,\n                power,\n                scale,\n                detail,\n                mask_inf,\n                dif_thresh,\n                dif_contr]\n\n    def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):\n        if not enabled:\n            return\n\n        if not processing_uses_inpainting(p):\n            return\n\n        # Shut off the rounding it normally does.\n        p.mask_round = False\n\n        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)\n\n        # p.extra_generation_params[\"Mask rounding\"] = False\n        settings.add_generation_params(p.extra_generation_params)\n\n    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf,\n                      dif_thresh, dif_contr):\n        if not enabled:\n            return\n\n        if not processing_uses_inpainting(p):\n            return\n\n        if mba.is_final_blend:\n            mba.blended_latent = mba.current_latent\n            return\n\n        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)\n\n        # todo: Why is sigma 2D? 
Both values are the same.\n        mba.blended_latent = latent_blend(settings,\n                                          mba.init_latent,\n                                          mba.current_latent,\n                                          get_modified_nmask(settings, mba.nmask, mba.sigma[0]))\n\n    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf,\n                    dif_thresh, dif_contr):\n        if not enabled:\n            return\n\n        if not processing_uses_inpainting(p):\n            return\n\n        nmask = getattr(p, \"nmask\", None)\n        if nmask is None:\n            return\n\n        from modules import images\n        from modules.shared import opts\n\n        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)\n\n        # since the original code puts holes in the existing overlay images,\n        # we have to rebuild them.\n        self.overlay_images = []\n        for img in p.init_images:\n\n            image = images.flatten(img, opts.img2img_background_color)\n\n            if p.paste_to is None and p.resize_mode != 3:\n                image = images.resize_image(p.resize_mode, image, p.width, p.height)\n\n            self.overlay_images.append(image.convert('RGBA'))\n\n        if len(p.init_images) == 1:\n            self.overlay_images = self.overlay_images * p.batch_size\n\n        if getattr(ps.samples, 'already_decoded', False):\n            self.masks_for_overlay = apply_masks(settings=settings,\n                                                 nmask=nmask,\n                                                 overlay_images=self.overlay_images,\n                                                 width=p.width,\n                                                 height=p.height,\n                                                 paste_to=p.paste_to)\n        else:\n            self.masks_for_overlay = 
apply_adaptive_masks(settings=settings,\n                                                          nmask=nmask,\n                                                          latent_orig=p.init_latent,\n                                                          latent_processed=ps.samples,\n                                                          overlay_images=self.overlay_images,\n                                                          width=p.width,\n                                                          height=p.height,\n                                                          paste_to=p.paste_to)\n\n    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale,\n                                detail_preservation, mask_inf, dif_thresh, dif_contr):\n        if not enabled:\n            return\n\n        if not processing_uses_inpainting(p):\n            return\n\n        if self.masks_for_overlay is None:\n            return\n\n        if self.overlay_images is None:\n            return\n\n        ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index]\n        ppmo.overlay_image = self.overlay_images[ppmo.index]\n"
  },
  {
    "path": "html/extra-networks-card.html",
    "content": "<div class=\"card\" style=\"{style}\" onclick=\"{card_clicked}\" data-name=\"{name}\" {sort_keys}>\n\t{background_image}\n\t<div class=\"button-row\">{copy_path_button}{metadata_button}{edit_button}</div>\n\t<div class=\"actions\">\n\t\t<div class=\"additional\">{search_terms}</div>\n\t\t<span class=\"name\">{name}</span>\n\t\t<span class=\"description\">{description}</span>\n\t</div>\n</div>\n"
  },
  {
    "path": "html/extra-networks-copy-path-button.html",
    "content": "<div class=\"copy-path-button card-button\"\n    title=\"Copy path to clipboard\"\n    onclick=\"extraNetworksCopyCardPath(event)\"\n    data-clipboard-text=\"{filename}\">\n</div>"
  },
  {
    "path": "html/extra-networks-edit-item-button.html",
    "content": "<div class=\"edit-button card-button\"\n    title=\"Edit metadata\"\n    onclick=\"extraNetworksEditUserMetadata(event, '{tabname}', '{extra_networks_tabname}')\">\n</div>"
  },
  {
    "path": "html/extra-networks-metadata-button.html",
    "content": "<div class=\"metadata-button card-button\"\n    title=\"Show internal metadata\"\n    onclick=\"extraNetworksRequestMetadata(event, '{extra_networks_tabname}')\">\n</div>"
  },
  {
    "path": "html/extra-networks-no-cards.html",
    "content": "<div class='nocards'>\n<h1>Nothing here. Add some content to the following directories:</h1>\n\n<ul>\n{dirs}\n</ul>\n</div>\n\n"
  },
  {
    "path": "html/extra-networks-pane-dirs.html",
    "content": "    <div class=\"extra-network-pane-content-dirs\">\r\n        <div id='{tabname}_{extra_networks_tabname}_dirs' class='extra-network-dirs'>\r\n            {dirs_html}\r\n        </div>\r\n        <div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards'>\r\n            {items_html}\r\n        </div>\r\n    </div>\r\n"
  },
  {
    "path": "html/extra-networks-pane-tree.html",
    "content": "    <div class=\"extra-network-pane-content-tree resize-handle-row\">\r\n        <div id='{tabname}_{extra_networks_tabname}_tree' class='extra-network-tree' style='flex-basis: {extra_networks_tree_view_default_width}px'>\r\n            {tree_html}\r\n        </div>\r\n        <div id='{tabname}_{extra_networks_tabname}_cards' class='extra-network-cards' style='flex-grow: 1;'>\r\n            {items_html}\r\n        </div>\r\n    </div>"
  },
  {
    "path": "html/extra-networks-pane.html",
    "content": "<div id='{tabname}_{extra_networks_tabname}_pane' class='extra-network-pane {tree_view_div_default_display_class}'>\n    <div class=\"extra-network-control\" id=\"{tabname}_{extra_networks_tabname}_controls\" style=\"display:none\" >\n        <div class=\"extra-network-control--search\">\n            <input\n                id=\"{tabname}_{extra_networks_tabname}_extra_search\"\n                class=\"extra-network-control--search-text\"\n                type=\"search\"\n                placeholder=\"Search\"\n            >\n        </div>\n\n        <small>Sort: </small>\n        <div \n            id=\"{tabname}_{extra_networks_tabname}_extra_sort_path\"\n            class=\"extra-network-control--sort{sort_path_active}\"\n            data-sortkey=\"default\"\n            title=\"Sort by path\"\n            onclick=\"extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--sort-icon\"></i>\n        </div>\n        <div\n            id=\"{tabname}_{extra_networks_tabname}_extra_sort_name\"\n            class=\"extra-network-control--sort{sort_name_active}\"\n            data-sortkey=\"name\"\n            title=\"Sort by name\"\n            onclick=\"extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--sort-icon\"></i>\n        </div>\n        <div\n            id=\"{tabname}_{extra_networks_tabname}_extra_sort_date_created\"\n            class=\"extra-network-control--sort{sort_date_created_active}\"\n            data-sortkey=\"date_created\"\n            title=\"Sort by date created\"\n            onclick=\"extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--sort-icon\"></i>\n        </div>\n        <div\n   
         id=\"{tabname}_{extra_networks_tabname}_extra_sort_date_modified\"\n            class=\"extra-network-control--sort{sort_date_modified_active}\"\n            data-sortkey=\"date_modified\"\n            title=\"Sort by date modified\"\n            onclick=\"extraNetworksControlSortOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--sort-icon\"></i>\n        </div>\n\n        <small> </small>\n        <div\n            id=\"{tabname}_{extra_networks_tabname}_extra_sort_dir\"\n            class=\"extra-network-control--sort-dir\"\n            data-sortdir=\"{data_sortdir}\"\n            title=\"Sort ascending\"\n            onclick=\"extraNetworksControlSortDirOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--sort-dir-icon\"></i>\n        </div>\n\n\n        <small> </small>\n        <div\n            id=\"{tabname}_{extra_networks_tabname}_extra_tree_view\"\n            class=\"extra-network-control--tree-view {tree_view_btn_extra_class}\"\n            title=\"Enable Tree View\"\n            onclick=\"extraNetworksControlTreeViewOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--tree-view-icon\"></i>\n        </div>\n        <div\n            id=\"{tabname}_{extra_networks_tabname}_extra_refresh\"\n            class=\"extra-network-control--refresh\"\n            title=\"Refresh page\"\n            onclick=\"extraNetworksControlRefreshOnClick(event, '{tabname}', '{extra_networks_tabname}');\"\n        >\n            <i class=\"extra-network-control--icon extra-network-control--refresh-icon\"></i>\n        </div>\n    </div>\n    {pane_content}\n</div>\n"
  },
  {
    "path": "html/extra-networks-tree-button.html",
    "content": "<span data-filterable-item-text hidden>{search_terms}</span>\n<div class=\"tree-list-content {subclass}\"\n    type=\"button\"\n    onclick=\"extraNetworksTreeOnClick(event, '{tabname}', '{extra_networks_tabname}');{onclick_extra}\"\n    data-path=\"{data_path}\"\n    data-hash=\"{data_hash}\"\n>\n    <span class='tree-list-item-action tree-list-item-action--leading'>\n        {action_list_item_action_leading}\n    </span>\n    <span class=\"tree-list-item-visual tree-list-item-visual--leading\">\n        {action_list_item_visual_leading}\n    </span>\n    <span class=\"tree-list-item-label tree-list-item-label--truncate\">\n        {action_list_item_label}\n    </span>\n    <span class=\"tree-list-item-visual tree-list-item-visual--trailing\">\n        {action_list_item_visual_trailing}\n    </span>\n    <span class=\"tree-list-item-action tree-list-item-action--trailing\">\n        {action_list_item_action_trailing}\n    </span>\n</div>"
  },
  {
    "path": "html/footer.html",
    "content": "<div>\r\n        <a href=\"{api_docs}\">API</a>\r\n         • \r\n        <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui\">Github</a>\r\n         • \r\n        <a href=\"https://gradio.app\">Gradio</a>\r\n         • \r\n        <a href=\"#\" onclick=\"showProfile('./internal/profile-startup'); return false;\">Startup profile</a>\r\n         • \r\n        <a href=\"/\" onclick=\"javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false\">Reload UI</a>\r\n</div>\r\n<br />\r\n<div class=\"versions\">\r\n{versions}\r\n</div>\r\n"
  },
  {
    "path": "html/licenses.html",
    "content": "<style>\r\n    #licenses h2 {font-size: 1.2em; font-weight: bold; margin-bottom: 0.2em;}\r\n    #licenses small {font-size: 0.95em; opacity: 0.85;}\r\n    #licenses pre { margin: 1em 0 2em 0;}\r\n</style>\r\n\r\n<h2><a href=\"https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE\">InvokeAI</a></h2>\r\n<small>Some code for compatibility with OSX is taken from lstein's repository.</small>\r\n<pre>\r\nMIT License\r\n\r\nCopyright (c) 2022 InvokeAI Team\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/Hafiidz/latent-diffusion/blob/main/LICENSE\">LDSR</a></h2>\r\n<small>Code added by contributors, most likely copied from this repository.</small>\r\n<pre>\r\nMIT License\r\n\r\nCopyright (c) 2022 Machine Vision and Learning Group, LMU Munich\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/pharmapsychotic/clip-interrogator/blob/main/LICENSE\">CLIP Interrogator</a></h2>\r\n<small>Some small amounts of code borrowed and reworked.</small>\r\n<pre>\r\nMIT License\r\n\r\nCopyright (c) 2022 pharmapsychotic\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE\">Memory Efficient Attention</a></h2>\r\n<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. 
This license is updated to reflect that.</small>\r\n<pre>\r\nMIT License\r\n\r\nCopyright (c) 2023 Alex Birch\r\nCopyright (c) 2023 Amin Rezaei\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/LICENSE\">Scaled Dot Product Attention</a></h2>\r\n<small>Some small amounts of code borrowed and reworked.</small>\r\n<pre>\r\n   Copyright 2023 The HuggingFace Team. 
All rights reserved.\r\n\r\n   Licensed under the Apache License, Version 2.0 (the \"License\");\r\n   you may not use this file except in compliance with the License.\r\n   You may obtain a copy of the License at\r\n\r\n      http://www.apache.org/licenses/LICENSE-2.0\r\n\r\n   Unless required by applicable law or agreed to in writing, software\r\n   distributed under the License is distributed on an \"AS IS\" BASIS,\r\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n   See the License for the specific language governing permissions and\r\n   limitations under the License.\r\n\r\n                                 Apache License\r\n                           Version 2.0, January 2004\r\n                        http://www.apache.org/licenses/\r\n\r\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\r\n\r\n   1. Definitions.\r\n\r\n      \"License\" shall mean the terms and conditions for use, reproduction,\r\n      and distribution as defined by Sections 1 through 9 of this document.\r\n\r\n      \"Licensor\" shall mean the copyright owner or entity authorized by\r\n      the copyright owner that is granting the License.\r\n\r\n      \"Legal Entity\" shall mean the union of the acting entity and all\r\n      other entities that control, are controlled by, or are under common\r\n      control with that entity. 
For the purposes of this definition,\r\n      \"control\" means (i) the power, direct or indirect, to cause the\r\n      direction or management of such entity, whether by contract or\r\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\r\n      outstanding shares, or (iii) beneficial ownership of such entity.\r\n\r\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\r\n      exercising permissions granted by this License.\r\n\r\n      \"Source\" form shall mean the preferred form for making modifications,\r\n      including but not limited to software source code, documentation\r\n      source, and configuration files.\r\n\r\n      \"Object\" form shall mean any form resulting from mechanical\r\n      transformation or translation of a Source form, including but\r\n      not limited to compiled object code, generated documentation,\r\n      and conversions to other media types.\r\n\r\n      \"Work\" shall mean the work of authorship, whether in Source or\r\n      Object form, made available under the License, as indicated by a\r\n      copyright notice that is included in or attached to the work\r\n      (an example is provided in the Appendix below).\r\n\r\n      \"Derivative Works\" shall mean any work, whether in Source or Object\r\n      form, that is based on (or derived from) the Work and for which the\r\n      editorial revisions, annotations, elaborations, or other modifications\r\n      represent, as a whole, an original work of authorship. 
For the purposes\r\n      of this License, Derivative Works shall not include works that remain\r\n      separable from, or merely link (or bind by name) to the interfaces of,\r\n      the Work and Derivative Works thereof.\r\n\r\n      \"Contribution\" shall mean any work of authorship, including\r\n      the original version of the Work and any modifications or additions\r\n      to that Work or Derivative Works thereof, that is intentionally\r\n      submitted to Licensor for inclusion in the Work by the copyright owner\r\n      or by an individual or Legal Entity authorized to submit on behalf of\r\n      the copyright owner. For the purposes of this definition, \"submitted\"\r\n      means any form of electronic, verbal, or written communication sent\r\n      to the Licensor or its representatives, including but not limited to\r\n      communication on electronic mailing lists, source code control systems,\r\n      and issue tracking systems that are managed by, or on behalf of, the\r\n      Licensor for the purpose of discussing and improving the Work, but\r\n      excluding communication that is conspicuously marked or otherwise\r\n      designated in writing by the copyright owner as \"Not a Contribution.\"\r\n\r\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\r\n      on behalf of whom a Contribution has been received by Licensor and\r\n      subsequently incorporated within the Work.\r\n\r\n   2. Grant of Copyright License. Subject to the terms and conditions of\r\n      this License, each Contributor hereby grants to You a perpetual,\r\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\r\n      copyright license to reproduce, prepare Derivative Works of,\r\n      publicly display, publicly perform, sublicense, and distribute the\r\n      Work and such Derivative Works in Source or Object form.\r\n\r\n   3. Grant of Patent License. 
Subject to the terms and conditions of\r\n      this License, each Contributor hereby grants to You a perpetual,\r\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\r\n      (except as stated in this section) patent license to make, have made,\r\n      use, offer to sell, sell, import, and otherwise transfer the Work,\r\n      where such license applies only to those patent claims licensable\r\n      by such Contributor that are necessarily infringed by their\r\n      Contribution(s) alone or by combination of their Contribution(s)\r\n      with the Work to which such Contribution(s) was submitted. If You\r\n      institute patent litigation against any entity (including a\r\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\r\n      or a Contribution incorporated within the Work constitutes direct\r\n      or contributory patent infringement, then any patent licenses\r\n      granted to You under this License for that Work shall terminate\r\n      as of the date such litigation is filed.\r\n\r\n   4. Redistribution. 
You may reproduce and distribute copies of the\r\n      Work or Derivative Works thereof in any medium, with or without\r\n      modifications, and in Source or Object form, provided that You\r\n      meet the following conditions:\r\n\r\n      (a) You must give any other recipients of the Work or\r\n          Derivative Works a copy of this License; and\r\n\r\n      (b) You must cause any modified files to carry prominent notices\r\n          stating that You changed the files; and\r\n\r\n      (c) You must retain, in the Source form of any Derivative Works\r\n          that You distribute, all copyright, patent, trademark, and\r\n          attribution notices from the Source form of the Work,\r\n          excluding those notices that do not pertain to any part of\r\n          the Derivative Works; and\r\n\r\n      (d) If the Work includes a \"NOTICE\" text file as part of its\r\n          distribution, then any Derivative Works that You distribute must\r\n          include a readable copy of the attribution notices contained\r\n          within such NOTICE file, excluding those notices that do not\r\n          pertain to any part of the Derivative Works, in at least one\r\n          of the following places: within a NOTICE text file distributed\r\n          as part of the Derivative Works; within the Source form or\r\n          documentation, if provided along with the Derivative Works; or,\r\n          within a display generated by the Derivative Works, if and\r\n          wherever such third-party notices normally appear. The contents\r\n          of the NOTICE file are for informational purposes only and\r\n          do not modify the License. 
You may add Your own attribution\r\n          notices within Derivative Works that You distribute, alongside\r\n          or as an addendum to the NOTICE text from the Work, provided\r\n          that such additional attribution notices cannot be construed\r\n          as modifying the License.\r\n\r\n      You may add Your own copyright statement to Your modifications and\r\n      may provide additional or different license terms and conditions\r\n      for use, reproduction, or distribution of Your modifications, or\r\n      for any such Derivative Works as a whole, provided Your use,\r\n      reproduction, and distribution of the Work otherwise complies with\r\n      the conditions stated in this License.\r\n\r\n   5. Submission of Contributions. Unless You explicitly state otherwise,\r\n      any Contribution intentionally submitted for inclusion in the Work\r\n      by You to the Licensor shall be under the terms and conditions of\r\n      this License, without any additional terms or conditions.\r\n      Notwithstanding the above, nothing herein shall supersede or modify\r\n      the terms of any separate license agreement you may have executed\r\n      with Licensor regarding such Contributions.\r\n\r\n   6. Trademarks. This License does not grant permission to use the trade\r\n      names, trademarks, service marks, or product names of the Licensor,\r\n      except as required for reasonable and customary use in describing the\r\n      origin of the Work and reproducing the content of the NOTICE file.\r\n\r\n   7. Disclaimer of Warranty. Unless required by applicable law or\r\n      agreed to in writing, Licensor provides the Work (and each\r\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\r\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\r\n      implied, including, without limitation, any warranties or conditions\r\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\r\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\r\n      appropriateness of using or redistributing the Work and assume any\r\n      risks associated with Your exercise of permissions under this License.\r\n\r\n   8. Limitation of Liability. In no event and under no legal theory,\r\n      whether in tort (including negligence), contract, or otherwise,\r\n      unless required by applicable law (such as deliberate and grossly\r\n      negligent acts) or agreed to in writing, shall any Contributor be\r\n      liable to You for damages, including any direct, indirect, special,\r\n      incidental, or consequential damages of any character arising as a\r\n      result of this License or out of the use or inability to use the\r\n      Work (including but not limited to damages for loss of goodwill,\r\n      work stoppage, computer failure or malfunction, or any and all\r\n      other commercial damages or losses), even if such Contributor\r\n      has been advised of the possibility of such damages.\r\n\r\n   9. Accepting Warranty or Additional Liability. While redistributing\r\n      the Work or Derivative Works thereof, You may choose to offer,\r\n      and charge a fee for, acceptance of support, warranty, indemnity,\r\n      or other liability obligations and/or rights consistent with this\r\n      License. 
However, in accepting such obligations, You may act only\r\n      on Your own behalf and on Your sole responsibility, not on behalf\r\n      of any other Contributor, and only if You agree to indemnify,\r\n      defend, and hold each Contributor harmless for any liability\r\n      incurred by, or claims asserted against, such Contributor by reason\r\n      of your accepting any such warranty or additional liability.\r\n\r\n   END OF TERMS AND CONDITIONS\r\n\r\n   APPENDIX: How to apply the Apache License to your work.\r\n\r\n      To apply the Apache License to your work, attach the following\r\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\r\n      replaced with your own identifying information. (Don't include\r\n      the brackets!)  The text should be enclosed in the appropriate\r\n      comment syntax for the file format. We also recommend that a\r\n      file or class name and description of purpose be included on the\r\n      same \"printed page\" as the copyright notice for easier\r\n      identification within third-party archives.\r\n\r\n   Copyright [yyyy] [name of copyright owner]\r\n\r\n   Licensed under the Apache License, Version 2.0 (the \"License\");\r\n   you may not use this file except in compliance with the License.\r\n   You may obtain a copy of the License at\r\n\r\n       http://www.apache.org/licenses/LICENSE-2.0\r\n\r\n   Unless required by applicable law or agreed to in writing, software\r\n   distributed under the License is distributed on an \"AS IS\" BASIS,\r\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r\n   See the License for the specific language governing permissions and\r\n   limitations under the License.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/explosion/curated-transformers/blob/main/LICENSE\">Curated transformers</a></h2>\r\n<small>The MPS workaround for nn.Linear on macOS 13.2.X is based on the MPS workaround for nn.Linear created by danieldk for Curated 
transformers</small>\r\n<pre>\r\nThe MIT License (MIT)\r\n\r\nCopyright (C) 2021 ExplosionAI GmbH\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in\r\nall copies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\r\nTHE SOFTWARE.\r\n</pre>\r\n\r\n<h2><a href=\"https://github.com/madebyollin/taesd/blob/main/LICENSE\">TAESD</a></h2>\r\n<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>\r\n<pre>\r\nMIT License\r\n\r\nCopyright (c) 2023 Ollin Boer Bohan\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial 
portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n</pre>\r\n"
  },
  {
    "path": "javascript/aspectRatioOverlay.js",
    "content": "\nlet currentWidth = null;\nlet currentHeight = null;\nlet arFrameTimeout = setTimeout(function() {}, 0);\n\nfunction dimensionChange(e, is_width, is_height) {\n\n    if (is_width) {\n        currentWidth = e.target.value * 1.0;\n    }\n    if (is_height) {\n        currentHeight = e.target.value * 1.0;\n    }\n\n    var inImg2img = gradioApp().querySelector(\"#tab_img2img\").style.display == \"block\";\n\n    if (!inImg2img) {\n        return;\n    }\n\n    var targetElement = null;\n\n    var tabIndex = get_tab_index('mode_img2img');\n    if (tabIndex == 0) { // img2img\n        targetElement = gradioApp().querySelector('#img2img_image div[data-testid=image] img');\n    } else if (tabIndex == 1) { //Sketch\n        targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img');\n    } else if (tabIndex == 2) { // Inpaint\n        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');\n    } else if (tabIndex == 3) { // Inpaint sketch\n        targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img');\n    }\n\n\n    if (targetElement) {\n\n        var arPreviewRect = gradioApp().querySelector('#imageARPreview');\n        if (!arPreviewRect) {\n            arPreviewRect = document.createElement('div');\n            arPreviewRect.id = \"imageARPreview\";\n            gradioApp().appendChild(arPreviewRect);\n        }\n\n\n\n        var viewportOffset = targetElement.getBoundingClientRect();\n\n        var viewportscale = Math.min(targetElement.clientWidth / targetElement.naturalWidth, targetElement.clientHeight / targetElement.naturalHeight);\n\n        var scaledx = targetElement.naturalWidth * viewportscale;\n        var scaledy = targetElement.naturalHeight * viewportscale;\n\n        var clientRectTop = (viewportOffset.top + window.scrollY);\n        var clientRectLeft = (viewportOffset.left + window.scrollX);\n        var clientRectCentreY = clientRectTop + 
(targetElement.clientHeight / 2);\n        var clientRectCentreX = clientRectLeft + (targetElement.clientWidth / 2);\n\n        var arscale = Math.min(scaledx / currentWidth, scaledy / currentHeight);\n        var arscaledx = currentWidth * arscale;\n        var arscaledy = currentHeight * arscale;\n\n        var arRectTop = clientRectCentreY - (arscaledy / 2);\n        var arRectLeft = clientRectCentreX - (arscaledx / 2);\n        var arRectWidth = arscaledx;\n        var arRectHeight = arscaledy;\n\n        arPreviewRect.style.top = arRectTop + 'px';\n        arPreviewRect.style.left = arRectLeft + 'px';\n        arPreviewRect.style.width = arRectWidth + 'px';\n        arPreviewRect.style.height = arRectHeight + 'px';\n\n        clearTimeout(arFrameTimeout);\n        arFrameTimeout = setTimeout(function() {\n            arPreviewRect.style.display = 'none';\n        }, 2000);\n\n        arPreviewRect.style.display = 'block';\n\n    }\n\n}\n\n\nonAfterUiUpdate(function() {\n    var arPreviewRect = gradioApp().querySelector('#imageARPreview');\n    if (arPreviewRect) {\n        arPreviewRect.style.display = 'none';\n    }\n    var tabImg2img = gradioApp().querySelector(\"#tab_img2img\");\n    if (tabImg2img) {\n        var inImg2img = tabImg2img.style.display == \"block\";\n        if (inImg2img) {\n            let inputs = gradioApp().querySelectorAll('input');\n            inputs.forEach(function(e) {\n                var is_width = e.parentElement.id == \"img2img_width\";\n                var is_height = e.parentElement.id == \"img2img_height\";\n\n                if ((is_width || is_height) && !e.classList.contains('scrollwatch')) {\n                    e.addEventListener('input', function(e) {\n                        dimensionChange(e, is_width, is_height);\n                    });\n                    e.classList.add('scrollwatch');\n                }\n                if (is_width) {\n                    currentWidth = e.value * 1.0;\n                }\n     
           if (is_height) {\n                    currentHeight = e.value * 1.0;\n                }\n            });\n        }\n    }\n});\n"
  },
  {
    "path": "javascript/contextMenus.js",
    "content": "\nvar contextMenuInit = function() {\n    let eventListenerApplied = false;\n    let menuSpecs = new Map();\n\n    const uid = function() {\n        return Date.now().toString(36) + Math.random().toString(36).substring(2);\n    };\n\n    function showContextMenu(event, element, menuEntries) {\n        let oldMenu = gradioApp().querySelector('#context-menu');\n        if (oldMenu) {\n            oldMenu.remove();\n        }\n\n        let baseStyle = window.getComputedStyle(uiCurrentTab);\n\n        const contextMenu = document.createElement('nav');\n        contextMenu.id = \"context-menu\";\n        contextMenu.style.background = baseStyle.background;\n        contextMenu.style.color = baseStyle.color;\n        contextMenu.style.fontFamily = baseStyle.fontFamily;\n        contextMenu.style.top = event.pageY + 'px';\n        contextMenu.style.left = event.pageX + 'px';\n\n        const contextMenuList = document.createElement('ul');\n        contextMenuList.className = 'context-menu-items';\n        contextMenu.append(contextMenuList);\n\n        menuEntries.forEach(function(entry) {\n            let contextMenuEntry = document.createElement('a');\n            contextMenuEntry.innerHTML = entry['name'];\n            contextMenuEntry.addEventListener(\"click\", function() {\n                entry['func']();\n            });\n            contextMenuList.append(contextMenuEntry);\n\n        });\n\n        gradioApp().appendChild(contextMenu);\n    }\n\n    function appendContextMenuOption(targetElementSelector, entryName, entryFunction) {\n\n        var currentItems = menuSpecs.get(targetElementSelector);\n\n        if (!currentItems) {\n            currentItems = [];\n            menuSpecs.set(targetElementSelector, currentItems);\n        }\n        let newItem = {\n            id: targetElementSelector + '_' + uid(),\n            name: entryName,\n            func: entryFunction,\n            isNew: true\n        };\n\n        
currentItems.push(newItem);\n        return newItem['id'];\n    }\n\n    function removeContextMenuOption(uid) {\n        menuSpecs.forEach(function(v) {\n            let index = -1;\n            v.forEach(function(e, ei) {\n                if (e['id'] == uid) {\n                    index = ei;\n                }\n            });\n            if (index >= 0) {\n                v.splice(index, 1);\n            }\n        });\n    }\n\n    function addContextMenuEventListener() {\n        if (eventListenerApplied) {\n            return;\n        }\n        gradioApp().addEventListener(\"click\", function(e) {\n            if (!e.isTrusted) {\n                return;\n            }\n\n            let oldMenu = gradioApp().querySelector('#context-menu');\n            if (oldMenu) {\n                oldMenu.remove();\n            }\n        });\n        ['contextmenu', 'touchstart'].forEach((eventType) => {\n            gradioApp().addEventListener(eventType, function(e) {\n                let ev = e;\n                if (eventType.startsWith('touch')) {\n                    if (e.touches.length !== 2) return;\n                    ev = e.touches[0];\n                }\n                let oldMenu = gradioApp().querySelector('#context-menu');\n                if (oldMenu) {\n                    oldMenu.remove();\n                }\n                menuSpecs.forEach(function(v, k) {\n                    if (e.composedPath()[0].matches(k)) {\n                        showContextMenu(ev, e.composedPath()[0], v);\n                        e.preventDefault();\n                    }\n                });\n            });\n        });\n        eventListenerApplied = true;\n\n    }\n\n    return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener];\n};\n\nvar initResponse = contextMenuInit();\nvar appendContextMenuOption = initResponse[0];\nvar removeContextMenuOption = initResponse[1];\nvar addContextMenuEventListener = initResponse[2];\n\n(function() 
{\n    //Start example Context Menu Items\n    let generateOnRepeat = function(genbuttonid, interruptbuttonid) {\n        let genbutton = gradioApp().querySelector(genbuttonid);\n        let interruptbutton = gradioApp().querySelector(interruptbuttonid);\n        if (!interruptbutton.offsetParent) {\n            genbutton.click();\n        }\n        clearInterval(window.generateOnRepeatInterval);\n        window.generateOnRepeatInterval = setInterval(function() {\n            if (!interruptbutton.offsetParent) {\n                genbutton.click();\n            }\n        },\n        500);\n    };\n\n    let generateOnRepeat_txt2img = function() {\n        generateOnRepeat('#txt2img_generate', '#txt2img_interrupt');\n    };\n\n    let generateOnRepeat_img2img = function() {\n        generateOnRepeat('#img2img_generate', '#img2img_interrupt');\n    };\n\n    appendContextMenuOption('#txt2img_generate', 'Generate forever', generateOnRepeat_txt2img);\n    appendContextMenuOption('#txt2img_interrupt', 'Generate forever', generateOnRepeat_txt2img);\n    appendContextMenuOption('#img2img_generate', 'Generate forever', generateOnRepeat_img2img);\n    appendContextMenuOption('#img2img_interrupt', 'Generate forever', generateOnRepeat_img2img);\n\n    let cancelGenerateForever = function() {\n        clearInterval(window.generateOnRepeatInterval);\n    };\n\n    appendContextMenuOption('#txt2img_interrupt', 'Cancel generate forever', cancelGenerateForever);\n    appendContextMenuOption('#txt2img_generate', 'Cancel generate forever', cancelGenerateForever);\n    appendContextMenuOption('#img2img_interrupt', 'Cancel generate forever', cancelGenerateForever);\n    appendContextMenuOption('#img2img_generate', 'Cancel generate forever', cancelGenerateForever);\n\n})();\n//End example Context Menu Items\n\nonAfterUiUpdate(addContextMenuEventListener);\n"
  },
  {
    "path": "javascript/dragdrop.js",
    "content": "// allows drag-dropping files into gradio image elements, and also pasting images from clipboard\n\nfunction isValidImageList(files) {\n    return files && files?.length === 1 && ['image/png', 'image/gif', 'image/jpeg'].includes(files[0].type);\n}\n\nfunction dropReplaceImage(imgWrap, files) {\n    if (!isValidImageList(files)) {\n        return;\n    }\n\n    const tmpFile = files[0];\n\n    imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click();\n    const callback = () => {\n        const fileInput = imgWrap.querySelector('input[type=\"file\"]');\n        if (fileInput) {\n            if (files.length === 0) {\n                files = new DataTransfer();\n                files.items.add(tmpFile);\n                fileInput.files = files.files;\n            } else {\n                fileInput.files = files;\n            }\n            fileInput.dispatchEvent(new Event('change'));\n        }\n    };\n\n    if (imgWrap.closest('#pnginfo_image')) {\n        // special treatment for PNG Info tab, wait for fetch request to finish\n        const oldFetch = window.fetch;\n        window.fetch = async(input, options) => {\n            const response = await oldFetch(input, options);\n            if ('api/predict/' === input) {\n                const content = await response.text();\n                window.fetch = oldFetch;\n                window.requestAnimationFrame(() => callback());\n                return new Response(content, {\n                    status: response.status,\n                    statusText: response.statusText,\n                    headers: response.headers\n                });\n            }\n            return response;\n        };\n    } else {\n        window.requestAnimationFrame(() => callback());\n    }\n}\n\nfunction eventHasFiles(e) {\n    if (!e.dataTransfer || !e.dataTransfer.files) return false;\n    if (e.dataTransfer.files.length > 0) return true;\n    if 
(e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == \"file\") return true;\n\n    return false;\n}\n\nfunction isURL(url) {\n    try {\n        const _ = new URL(url);\n        return true;\n    } catch {\n        return false;\n    }\n}\n\nfunction dragDropTargetIsPrompt(target) {\n    if (target?.placeholder && target?.placeholder.indexOf(\"Prompt\") >= 0) return true;\n    if (target?.parentNode?.parentNode?.className?.indexOf(\"prompt\") > 0) return true;\n    return false;\n}\n\nwindow.document.addEventListener('dragover', e => {\n    const target = e.composedPath()[0];\n    if (!eventHasFiles(e)) return;\n\n    var targetImage = target.closest('[data-testid=\"image\"]');\n    if (!dragDropTargetIsPrompt(target) && !targetImage) return;\n\n    e.stopPropagation();\n    e.preventDefault();\n    e.dataTransfer.dropEffect = 'copy';\n});\n\nwindow.document.addEventListener('drop', async e => {\n    const target = e.composedPath()[0];\n    const url = e.dataTransfer.getData('text/uri-list') || e.dataTransfer.getData('text/plain');\n    if (!eventHasFiles(e) && !isURL(url)) return;\n\n    if (dragDropTargetIsPrompt(target)) {\n        e.stopPropagation();\n        e.preventDefault();\n\n        const isImg2img = get_tab_index('tabs') == 1;\n        let prompt_image_target = isImg2img ? 
\"img2img_prompt_image\" : \"txt2img_prompt_image\";\n\n        const imgParent = gradioApp().getElementById(prompt_image_target);\n        const files = e.dataTransfer.files;\n        const fileInput = imgParent.querySelector('input[type=\"file\"]');\n        if (eventHasFiles(e) && fileInput) {\n            fileInput.files = files;\n            fileInput.dispatchEvent(new Event('change'));\n        } else if (url) {\n            try {\n                const request = await fetch(url);\n                if (!request.ok) {\n                    console.error('Error fetching URL:', url, request.status);\n                    return;\n                }\n                const data = new DataTransfer();\n                data.items.add(new File([await request.blob()], 'image.png'));\n                fileInput.files = data.files;\n                fileInput.dispatchEvent(new Event('change'));\n            } catch (error) {\n                console.error('Error fetching URL:', url, error);\n                return;\n            }\n        }\n    }\n\n    var targetImage = target.closest('[data-testid=\"image\"]');\n    if (targetImage) {\n        e.stopPropagation();\n        e.preventDefault();\n        const files = e.dataTransfer.files;\n        dropReplaceImage(targetImage, files);\n        return;\n    }\n});\n\nwindow.addEventListener('paste', e => {\n    const files = e.clipboardData.files;\n    if (!isValidImageList(files)) {\n        return;\n    }\n\n    const visibleImageFields = [...gradioApp().querySelectorAll('[data-testid=\"image\"]')]\n        .filter(el => uiElementIsVisible(el))\n        .sort((a, b) => uiElementInSight(b) - uiElementInSight(a));\n\n\n    if (!visibleImageFields.length) {\n        return;\n    }\n\n    const firstFreeImageField = visibleImageFields\n        .filter(el => !el.querySelector('img'))?.[0];\n\n    dropReplaceImage(\n        firstFreeImageField ?\n            firstFreeImageField :\n            
visibleImageFields[visibleImageFields.length - 1]\n        , files\n    );\n});\n"
  },
  {
    "path": "javascript/edit-attention.js",
    "content": "function keyupEditAttention(event) {\n    let target = event.originalTarget || event.composedPath()[0];\n    if (!target.matches(\"*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea\")) return;\n    if (!(event.metaKey || event.ctrlKey)) return;\n\n    let isPlus = event.key == \"ArrowUp\";\n    let isMinus = event.key == \"ArrowDown\";\n    if (!isPlus && !isMinus) return;\n\n    let selectionStart = target.selectionStart;\n    let selectionEnd = target.selectionEnd;\n    let text = target.value;\n\n    function selectCurrentParenthesisBlock(OPEN, CLOSE) {\n        if (selectionStart !== selectionEnd) return false;\n\n        // Find opening parenthesis around current cursor\n        const before = text.substring(0, selectionStart);\n        let beforeParen = before.lastIndexOf(OPEN);\n        if (beforeParen == -1) return false;\n\n        let beforeClosingParen = before.lastIndexOf(CLOSE);\n        if (beforeClosingParen != -1 && beforeClosingParen > beforeParen) return false;\n\n        // Find closing parenthesis around current cursor\n        const after = text.substring(selectionStart);\n        let afterParen = after.indexOf(CLOSE);\n        if (afterParen == -1) return false;\n\n        let afterOpeningParen = after.indexOf(OPEN);\n        if (afterOpeningParen != -1 && afterOpeningParen < afterParen) return false;\n\n        // Set the selection to the text between the parenthesis\n        const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen);\n        if (/.*:-?[\\d.]+/s.test(parenContent)) {\n            const lastColon = parenContent.lastIndexOf(\":\");\n            selectionStart = beforeParen + 1;\n            selectionEnd = selectionStart + lastColon;\n        } else {\n            selectionStart = beforeParen + 1;\n            selectionEnd = selectionStart + parenContent.length;\n        }\n\n        target.setSelectionRange(selectionStart, selectionEnd);\n        return true;\n    }\n\n    function 
selectCurrentWord() {\n        if (selectionStart !== selectionEnd) return false;\n        const whitespace_delimiters = {\"Tab\": \"\\t\", \"Carriage Return\": \"\\r\", \"Line Feed\": \"\\n\"};\n        let delimiters = opts.keyedit_delimiters;\n\n        for (let i of opts.keyedit_delimiters_whitespace) {\n            delimiters += whitespace_delimiters[i];\n        }\n\n        // seek backward to find beginning\n        while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {\n            selectionStart--;\n        }\n\n        // seek forward to find end\n        while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {\n            selectionEnd++;\n        }\n\n        // deselect surrounding whitespace\n        while (text[selectionStart] == \" \" && selectionStart < selectionEnd) {\n            selectionStart++;\n        }\n        while (text[selectionEnd - 1] == \" \" && selectionEnd > selectionStart) {\n            selectionEnd--;\n        }\n\n        target.setSelectionRange(selectionStart, selectionEnd);\n        return true;\n    }\n\n    // If the user hasn't selected anything, let's select their current parenthesis block or word\n    if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')') && !selectCurrentParenthesisBlock('[', ']')) {\n        selectCurrentWord();\n    }\n\n    event.preventDefault();\n\n    var closeCharacter = ')';\n    var delta = opts.keyedit_precision_attention;\n    var start = selectionStart > 0 ? 
text[selectionStart - 1] : \"\";\n    var end = text[selectionEnd];\n\n    if (start == '<') {\n        closeCharacter = '>';\n        delta = opts.keyedit_precision_extra;\n    } else if (start == '(' && end == ')' || start == '[' && end == ']') { // convert old-style (((emphasis)))\n        let numParen = 0;\n\n        while (text[selectionStart - numParen - 1] == start && text[selectionEnd + numParen] == end) {\n            numParen++;\n        }\n\n        if (start == \"[\") {\n            weight = (1 / 1.1) ** numParen;\n        } else {\n            weight = 1.1 ** numParen;\n        }\n\n        weight = Math.round(weight / opts.keyedit_precision_attention) * opts.keyedit_precision_attention;\n\n        text = text.slice(0, selectionStart - numParen) + \"(\" + text.slice(selectionStart, selectionEnd) + \":\" + weight + \")\" + text.slice(selectionEnd + numParen);\n        selectionStart -= numParen - 1;\n        selectionEnd -= numParen - 1;\n    } else if (start != '(') {\n        // do not include spaces at the end\n        while (selectionEnd > selectionStart && text[selectionEnd - 1] == ' ') {\n            selectionEnd--;\n        }\n\n        if (selectionStart == selectionEnd) {\n            return;\n        }\n\n        text = text.slice(0, selectionStart) + \"(\" + text.slice(selectionStart, selectionEnd) + \":1.0)\" + text.slice(selectionEnd);\n\n        selectionStart++;\n        selectionEnd++;\n    }\n\n    if (text[selectionEnd] != ':') return;\n    var weightLength = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;\n    var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + weightLength));\n    if (isNaN(weight)) return;\n\n    weight += isPlus ? 
delta : -delta;\n    weight = parseFloat(weight.toPrecision(12));\n    if (Number.isInteger(weight)) weight += \".0\";\n\n    if (closeCharacter == ')' && weight == 1) {\n        var endParenPos = text.substring(selectionEnd).indexOf(')');\n        text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + endParenPos + 1);\n        selectionStart--;\n        selectionEnd--;\n    } else {\n        text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + weightLength);\n    }\n\n    target.focus();\n    target.value = text;\n    target.selectionStart = selectionStart;\n    target.selectionEnd = selectionEnd;\n\n    updateInput(target);\n}\n\naddEventListener('keydown', (event) => {\n    keyupEditAttention(event);\n});\n"
  },
  {
    "path": "javascript/edit-order.js",
    "content": "/* alt+left/right moves text in prompt */\n\nfunction keyupEditOrder(event) {\n    if (!opts.keyedit_move) return;\n\n    let target = event.originalTarget || event.composedPath()[0];\n    if (!target.matches(\"*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea\")) return;\n    if (!event.altKey) return;\n\n    let isLeft = event.key == \"ArrowLeft\";\n    let isRight = event.key == \"ArrowRight\";\n    if (!isLeft && !isRight) return;\n    event.preventDefault();\n\n    let selectionStart = target.selectionStart;\n    let selectionEnd = target.selectionEnd;\n    let text = target.value;\n    let items = text.split(\",\");\n    let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length;\n    let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length;\n    let range = indexEnd - indexStart + 1;\n\n    if (isLeft && indexStart > 0) {\n        items.splice(indexStart - 1, 0, ...items.splice(indexStart, range));\n        target.value = items.join();\n        target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1);\n        target.selectionEnd = items.slice(0, indexEnd).join().length;\n    } else if (isRight && indexEnd < items.length - 1) {\n        items.splice(indexStart + 1, 0, ...items.splice(indexStart, range));\n        target.value = items.join();\n        target.selectionStart = items.slice(0, indexStart + 1).join().length + 1;\n        target.selectionEnd = items.slice(0, indexEnd + 2).join().length;\n    }\n\n    event.preventDefault();\n    updateInput(target);\n}\n\naddEventListener('keydown', (event) => {\n    keyupEditOrder(event);\n});\n"
  },
  {
    "path": "javascript/extensions.js",
    "content": "\nfunction extensions_apply(_disabled_list, _update_list, disable_all) {\n    var disable = [];\n    var update = [];\n    const extensions_input = gradioApp().querySelectorAll('#extensions input[type=\"checkbox\"]');\n    if (extensions_input.length == 0) {\n        throw Error(\"Extensions page not yet loaded.\");\n    }\n    extensions_input.forEach(function(x) {\n        if (x.name.startsWith(\"enable_\") && !x.checked) {\n            disable.push(x.name.substring(7));\n        }\n\n        if (x.name.startsWith(\"update_\") && x.checked) {\n            update.push(x.name.substring(7));\n        }\n    });\n\n    restart_reload();\n\n    return [JSON.stringify(disable), JSON.stringify(update), disable_all];\n}\n\nfunction extensions_check() {\n    var disable = [];\n\n    gradioApp().querySelectorAll('#extensions input[type=\"checkbox\"]').forEach(function(x) {\n        if (x.name.startsWith(\"enable_\") && !x.checked) {\n            disable.push(x.name.substring(7));\n        }\n    });\n\n    gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {\n        x.innerHTML = \"Loading...\";\n    });\n\n\n    var id = randomId();\n    requestProgress(id, gradioApp().getElementById('extensions_installed_html'), null, function() {\n\n    });\n\n    return [id, JSON.stringify(disable)];\n}\n\nfunction install_extension_from_index(button, url) {\n    button.disabled = \"disabled\";\n    button.value = \"Installing...\";\n\n    var textarea = gradioApp().querySelector('#extension_to_install textarea');\n    textarea.value = url;\n    updateInput(textarea);\n\n    gradioApp().querySelector('#install_extension_button').click();\n}\n\nfunction config_state_confirm_restore(_, config_state_name, config_restore_type) {\n    if (config_state_name == \"Current\") {\n        return [false, config_state_name, config_restore_type];\n    }\n    let restored = \"\";\n    if (config_restore_type == \"extensions\") {\n        restored = 
\"all saved extension versions\";\n    } else if (config_restore_type == \"webui\") {\n        restored = \"the webui version\";\n    } else {\n        restored = \"the webui version and all saved extension versions\";\n    }\n    let confirmed = confirm(\"Are you sure you want to restore from this state?\\nThis will reset \" + restored + \".\");\n    if (confirmed) {\n        restart_reload();\n        gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x) {\n            x.innerHTML = \"Loading...\";\n        });\n    }\n    return [confirmed, config_state_name, config_restore_type];\n}\n\nfunction toggle_all_extensions(event) {\n    gradioApp().querySelectorAll('#extensions .extension_toggle').forEach(function(checkbox_el) {\n        checkbox_el.checked = event.target.checked;\n    });\n}\n\nfunction toggle_extension() {\n    let all_extensions_toggled = true;\n    for (const checkbox_el of gradioApp().querySelectorAll('#extensions .extension_toggle')) {\n        if (!checkbox_el.checked) {\n            all_extensions_toggled = false;\n            break;\n        }\n    }\n\n    gradioApp().querySelector('#extensions .all_extensions_toggle').checked = all_extensions_toggled;\n}\n"
  },
  {
    "path": "javascript/extraNetworks.js",
    "content": "function toggleCss(key, css, enable) {\n    var style = document.getElementById(key);\n    if (enable && !style) {\n        style = document.createElement('style');\n        style.id = key;\n        style.type = 'text/css';\n        document.head.appendChild(style);\n    }\n    if (style && !enable) {\n        document.head.removeChild(style);\n    }\n    if (style) {\n        style.innerHTML == '';\n        style.appendChild(document.createTextNode(css));\n    }\n}\n\nfunction setupExtraNetworksForTab(tabname) {\n    function registerPrompt(tabname, id) {\n        var textarea = gradioApp().querySelector(\"#\" + id + \" > label > textarea\");\n\n        if (!activePromptTextarea[tabname]) {\n            activePromptTextarea[tabname] = textarea;\n        }\n\n        textarea.addEventListener(\"focus\", function() {\n            activePromptTextarea[tabname] = textarea;\n        });\n    }\n\n    var tabnav = gradioApp().querySelector('#' + tabname + '_extra_tabs > div.tab-nav');\n    var controlsDiv = document.createElement('DIV');\n    controlsDiv.classList.add('extra-networks-controls-div');\n    tabnav.appendChild(controlsDiv);\n    tabnav.insertBefore(controlsDiv, null);\n\n    var this_tab = gradioApp().querySelector('#' + tabname + '_extra_tabs');\n    this_tab.querySelectorAll(\":scope > [id^='\" + tabname + \"_']\").forEach(function(elem) {\n        // tabname_full = {tabname}_{extra_networks_tabname}\n        var tabname_full = elem.id;\n        var search = gradioApp().querySelector(\"#\" + tabname_full + \"_extra_search\");\n        var sort_dir = gradioApp().querySelector(\"#\" + tabname_full + \"_extra_sort_dir\");\n        var refresh = gradioApp().querySelector(\"#\" + tabname_full + \"_extra_refresh\");\n        var currentSort = '';\n\n        // If any of the buttons above don't exist, we want to skip this iteration of the loop.\n        if (!search || !sort_dir || !refresh) {\n            return; // `return` is equivalent of 
`continue` but for forEach loops.\n        }\n\n        var applyFilter = function(force) {\n            var searchTerm = search.value.toLowerCase();\n            gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card').forEach(function(elem) {\n                var searchOnly = elem.querySelector('.search_only');\n                var text = Array.prototype.map.call(elem.querySelectorAll('.search_terms, .description'), function(t) {\n                    return t.textContent.toLowerCase();\n                }).join(\" \");\n\n                var visible = text.indexOf(searchTerm) != -1;\n                if (searchOnly && searchTerm.length < 4) {\n                    visible = false;\n                }\n                if (visible) {\n                    elem.classList.remove(\"hidden\");\n                } else {\n                    elem.classList.add(\"hidden\");\n                }\n            });\n\n            applySort(force);\n        };\n\n        var applySort = function(force) {\n            var cards = gradioApp().querySelectorAll('#' + tabname_full + ' div.card');\n            var parent = gradioApp().querySelector('#' + tabname_full + \"_cards\");\n            var reverse = sort_dir.dataset.sortdir == \"Descending\";\n            var activeSearchElem = gradioApp().querySelector('#' + tabname_full + \"_controls .extra-network-control--sort.extra-network-control--enabled\");\n            var sortKey = activeSearchElem ? 
activeSearchElem.dataset.sortkey : \"default\";\n            var sortKeyDataField = \"sort\" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1);\n            var sortKeyStore = sortKey + \"-\" + sort_dir.dataset.sortdir + \"-\" + cards.length;\n\n            if (sortKeyStore == currentSort && !force) {\n                return;\n            }\n            currentSort = sortKeyStore;\n\n            var sortedCards = Array.from(cards);\n            sortedCards.sort(function(cardA, cardB) {\n                var a = cardA.dataset[sortKeyDataField];\n                var b = cardB.dataset[sortKeyDataField];\n                if (!isNaN(a) && !isNaN(b)) {\n                    return parseInt(a) - parseInt(b);\n                }\n\n                return (a < b ? -1 : (a > b ? 1 : 0));\n            });\n\n            if (reverse) {\n                sortedCards.reverse();\n            }\n\n            parent.innerHTML = '';\n\n            var frag = document.createDocumentFragment();\n            sortedCards.forEach(function(card) {\n                frag.appendChild(card);\n            });\n            parent.appendChild(frag);\n        };\n\n        search.addEventListener(\"input\", function() {\n            applyFilter();\n        });\n        applySort();\n        applyFilter();\n        extraNetworksApplySort[tabname_full] = applySort;\n        extraNetworksApplyFilter[tabname_full] = applyFilter;\n\n        var controls = gradioApp().querySelector(\"#\" + tabname_full + \"_controls\");\n        controlsDiv.insertBefore(controls, null);\n\n        if (elem.style.display != \"none\") {\n            extraNetworksShowControlsForPage(tabname, tabname_full);\n        }\n    });\n\n    registerPrompt(tabname, tabname + \"_prompt\");\n    registerPrompt(tabname, tabname + \"_neg_prompt\");\n}\n\nfunction extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt) {\n    if (!gradioApp().querySelector('.toprow-compact-tools')) return; // only applicable for 
compact prompt layout\n\n    var promptContainer = gradioApp().getElementById(tabname + '_prompt_container');\n    var prompt = gradioApp().getElementById(tabname + '_prompt_row');\n    var negPrompt = gradioApp().getElementById(tabname + '_neg_prompt_row');\n    var elem = id ? gradioApp().getElementById(id) : null;\n\n    if (showNegativePrompt && elem) {\n        elem.insertBefore(negPrompt, elem.firstChild);\n    } else {\n        promptContainer.insertBefore(negPrompt, promptContainer.firstChild);\n    }\n\n    if (showPrompt && elem) {\n        elem.insertBefore(prompt, elem.firstChild);\n    } else {\n        promptContainer.insertBefore(prompt, promptContainer.firstChild);\n    }\n\n    if (elem) {\n        elem.classList.toggle('extra-page-prompts-active', showNegativePrompt || showPrompt);\n    }\n}\n\n\nfunction extraNetworksShowControlsForPage(tabname, tabname_full) {\n    gradioApp().querySelectorAll('#' + tabname + '_extra_tabs .extra-networks-controls-div > div').forEach(function(elem) {\n        var targetId = tabname_full + \"_controls\";\n        elem.style.display = elem.id == targetId ? 
\"\" : \"none\";\n    });\n}\n\n\nfunction extraNetworksUnrelatedTabSelected(tabname) { // called from python when user selects an unrelated tab (generate)\n    extraNetworksMovePromptToTab(tabname, '', false, false);\n\n    extraNetworksShowControlsForPage(tabname, null);\n}\n\nfunction extraNetworksTabSelected(tabname, id, showPrompt, showNegativePrompt, tabname_full) { // called from python when user selects an extra networks tab\n    extraNetworksMovePromptToTab(tabname, id, showPrompt, showNegativePrompt);\n\n    extraNetworksShowControlsForPage(tabname, tabname_full);\n}\n\nfunction applyExtraNetworkFilter(tabname_full) {\n    var doFilter = function() {\n        var applyFunction = extraNetworksApplyFilter[tabname_full];\n\n        if (applyFunction) {\n            applyFunction(true);\n        }\n    };\n    setTimeout(doFilter, 1);\n}\n\nfunction applyExtraNetworkSort(tabname_full) {\n    var doSort = function() {\n        extraNetworksApplySort[tabname_full](true);\n    };\n    setTimeout(doSort, 1);\n}\n\nvar extraNetworksApplyFilter = {};\nvar extraNetworksApplySort = {};\nvar activePromptTextarea = {};\n\nfunction setupExtraNetworks() {\n    setupExtraNetworksForTab('txt2img');\n    setupExtraNetworksForTab('img2img');\n}\n\nvar re_extranet = /<([^:^>]+:[^:]+):[\\d.]+>(.*)/;\nvar re_extranet_g = /<([^:^>]+:[^:]+):[\\d.]+>/g;\n\nvar re_extranet_neg = /\\(([^:^>]+:[\\d.]+)\\)/;\nvar re_extranet_g_neg = /\\(([^:^>]+:[\\d.]+)\\)/g;\nfunction tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) {\n    var m = text.match(isNeg ? re_extranet_neg : re_extranet);\n    var replaced = false;\n    var newTextareaText;\n    var extraTextBeforeNet = opts.extra_networks_add_text_separator;\n    if (m) {\n        var extraTextAfterNet = m[2];\n        var partToSearch = m[1];\n        var foundAtPosition = -1;\n        newTextareaText = textarea.value.replaceAll(isNeg ? 
re_extranet_g_neg : re_extranet_g, function(found, net, pos) {\n            m = found.match(isNeg ? re_extranet_neg : re_extranet);\n            if (m[1] == partToSearch) {\n                replaced = true;\n                foundAtPosition = pos;\n                return \"\";\n            }\n            return found;\n        });\n        if (foundAtPosition >= 0) {\n            if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) {\n                newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length);\n            }\n            if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) {\n                newTextareaText = newTextareaText.substr(0, foundAtPosition - extraTextBeforeNet.length) + newTextareaText.substr(foundAtPosition);\n            }\n        }\n    } else {\n        newTextareaText = textarea.value.replaceAll(new RegExp(`((?:${extraTextBeforeNet})?${text})`, \"g\"), \"\");\n        replaced = (newTextareaText != textarea.value);\n    }\n\n    if (replaced) {\n        textarea.value = newTextareaText;\n        return true;\n    }\n\n    return false;\n}\n\nfunction updatePromptArea(text, textArea, isNeg) {\n    if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) {\n        textArea.value = textArea.value + opts.extra_networks_add_text_separator + text;\n    }\n\n    updateInput(textArea);\n}\n\nfunction cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) {\n    if (textToAddNegative.length > 0) {\n        updatePromptArea(textToAdd, gradioApp().querySelector(\"#\" + tabname + \"_prompt > label > textarea\"));\n        updatePromptArea(textToAddNegative, gradioApp().querySelector(\"#\" + tabname + \"_neg_prompt > label > textarea\"), true);\n    } else {\n        var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector(\"#\" + tabname + \"_prompt > label > textarea\");\n        updatePromptArea(textToAdd, textarea);\n    }\n}\n\nfunction saveCardPreview(event, tabname, filename) {\n    var textarea = gradioApp().querySelector(\"#\" + tabname + '_preview_filename  > label > textarea');\n    var button = gradioApp().getElementById(tabname + '_save_preview');\n\n    textarea.value = filename;\n    updateInput(textarea);\n\n    button.click();\n\n    event.stopPropagation();\n    event.preventDefault();\n}\n\nfunction extraNetworksSearchButton(tabname, extra_networks_tabname, event) {\n    var searchTextarea = gradioApp().querySelector(\"#\" + tabname + \"_\" + extra_networks_tabname + \"_extra_search\");\n    var button = event.target;\n    var text = button.classList.contains(\"search-all\") ? \"\" : button.textContent.trim();\n\n    searchTextarea.value = text;\n    updateInput(searchTextarea);\n}\n\nfunction extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname) {\n    /**\n     * Processes `onclick` events when user clicks on files in tree.\n     *\n     * @param event                     The generated event.\n     * @param btn                       The clicked `tree-list-item` button.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. 
Ex: lora, checkpoints, etc.\n     */\n    // NOTE: Currently unused.\n    return;\n}\n\nfunction extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname) {\n    /**\n     * Processes `onclick` events when user clicks on directories in tree.\n     *\n     * Here is how the tree reacts to clicks for various states:\n     * unselected unopened directory: Directory is selected and expanded.\n     * unselected opened directory: Directory is selected.\n     * selected opened directory: Directory is collapsed and deselected.\n     * chevron is clicked: Directory is expanded or collapsed. Selected state unchanged.\n     *\n     * @param event                     The generated event.\n     * @param btn                       The clicked `tree-list-item` button.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.\n     */\n    var ul = btn.nextElementSibling;\n    // This is the actual target that the user clicked on within the target button.\n    // We use this to detect if the chevron was clicked.\n    var true_targ = event.target;\n\n    function _expand_or_collapse(_ul, _btn) {\n        // Expands <ul> if it is collapsed, collapses otherwise. 
Updates button attributes.\n        if (_ul.hasAttribute(\"hidden\")) {\n            _ul.removeAttribute(\"hidden\");\n            _btn.dataset.expanded = \"\";\n        } else {\n            _ul.setAttribute(\"hidden\", \"\");\n            delete _btn.dataset.expanded;\n        }\n    }\n\n    function _remove_selected_from_all() {\n        // Removes the `selected` attribute from all buttons.\n        var sels = document.querySelectorAll(\"div.tree-list-content\");\n        [...sels].forEach(el => {\n            delete el.dataset.selected;\n        });\n    }\n\n    function _select_button(_btn) {\n        // Removes `data-selected` attribute from all buttons then adds to passed button.\n        _remove_selected_from_all();\n        _btn.dataset.selected = \"\";\n    }\n\n    function _update_search(_tabname, _extra_networks_tabname, _search_text) {\n        // Update search input with select button's path.\n        var search_input_elem = gradioApp().querySelector(\"#\" + tabname + \"_\" + extra_networks_tabname + \"_extra_search\");\n        search_input_elem.value = _search_text;\n        updateInput(search_input_elem);\n    }\n\n\n    // If user clicks on the chevron, then we do not select the folder.\n    if (true_targ.matches(\".tree-list-item-action--leading, .tree-list-item-action-chevron\")) {\n        _expand_or_collapse(ul, btn);\n    } else {\n        // User clicked anywhere else on the button.\n        if (\"selected\" in btn.dataset && !(ul.hasAttribute(\"hidden\"))) {\n            // If folder is select and open, collapse and deselect button.\n            _expand_or_collapse(ul, btn);\n            delete btn.dataset.selected;\n            _update_search(tabname, extra_networks_tabname, \"\");\n        } else if (!(!(\"selected\" in btn.dataset) && !(ul.hasAttribute(\"hidden\")))) {\n            // If folder is open and not selected, then we don't collapse; just select.\n            // NOTE: Double inversion sucks but it is the clearest way to show 
the branching here.\n            _expand_or_collapse(ul, btn);\n            _select_button(btn, tabname, extra_networks_tabname);\n            _update_search(tabname, extra_networks_tabname, btn.dataset.path);\n        } else {\n            // All other cases, just select the button.\n            _select_button(btn, tabname, extra_networks_tabname);\n            _update_search(tabname, extra_networks_tabname, btn.dataset.path);\n        }\n    }\n}\n\nfunction extraNetworksTreeOnClick(event, tabname, extra_networks_tabname) {\n    /**\n     * Handles `onclick` events for buttons within an `extra-network-tree .tree-list--tree`.\n     *\n     * Determines whether the clicked button in the tree is for a file entry or a directory\n     * then calls the appropriate function.\n     *\n     * @param event                     The generated event.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.\n     */\n    var btn = event.currentTarget;\n    var par = btn.parentElement;\n    if (par.dataset.treeEntryType === \"file\") {\n        extraNetworksTreeProcessFileClick(event, btn, tabname, extra_networks_tabname);\n    } else {\n        extraNetworksTreeProcessDirectoryClick(event, btn, tabname, extra_networks_tabname);\n    }\n}\n\nfunction extraNetworksControlSortOnClick(event, tabname, extra_networks_tabname) {\n    /** Handles `onclick` events for Sort Mode buttons. 
*/\n\n    var self = event.currentTarget;\n    var parent = event.currentTarget.parentElement;\n\n    parent.querySelectorAll('.extra-network-control--sort').forEach(function(x) {\n        x.classList.remove('extra-network-control--enabled');\n    });\n\n    self.classList.add('extra-network-control--enabled');\n\n    applyExtraNetworkSort(tabname + \"_\" + extra_networks_tabname);\n}\n\nfunction extraNetworksControlSortDirOnClick(event, tabname, extra_networks_tabname) {\n    /**\n     * Handles `onclick` events for the Sort Direction button.\n     *\n     * Modifies the data attributes of the Sort Direction button to cycle between\n     * ascending and descending sort directions.\n     *\n     * @param event                     The generated event.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. Ex: lora, checkpoints, etc.\n     */\n    if (event.currentTarget.dataset.sortdir == \"Ascending\") {\n        event.currentTarget.dataset.sortdir = \"Descending\";\n        event.currentTarget.setAttribute(\"title\", \"Sort descending\");\n    } else {\n        event.currentTarget.dataset.sortdir = \"Ascending\";\n        event.currentTarget.setAttribute(\"title\", \"Sort ascending\");\n    }\n    applyExtraNetworkSort(tabname + \"_\" + extra_networks_tabname);\n}\n\nfunction extraNetworksControlTreeViewOnClick(event, tabname, extra_networks_tabname) {\n    /**\n     * Handles `onclick` events for the Tree View button.\n     *\n     * Toggles the tree view in the extra networks pane.\n     *\n     * @param event                     The generated event.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. 
Ex: lora, checkpoints, etc.\n     */\n    var button = event.currentTarget;\n    button.classList.toggle(\"extra-network-control--enabled\");\n    var show = !button.classList.contains(\"extra-network-control--enabled\");\n\n    var pane = gradioApp().getElementById(tabname + \"_\" + extra_networks_tabname + \"_pane\");\n    pane.classList.toggle(\"extra-network-dirs-hidden\", show);\n}\n\nfunction extraNetworksControlRefreshOnClick(event, tabname, extra_networks_tabname) {\n    /**\n     * Handles `onclick` events for the Refresh Page button.\n     *\n     * In order to actually call the python functions in `ui_extra_networks.py`\n     * to refresh the page, we created an empty gradio button in that file with an\n     * event handler that refreshes the page. So what this function here does\n     * is it manually raises a `click` event on that button.\n     *\n     * @param event                     The generated event.\n     * @param tabname                   The name of the active tab in the sd webui. Ex: txt2img, img2img, etc.\n     * @param extra_networks_tabname    The id of the active extraNetworks tab. 
Ex: lora, checkpoints, etc.\n     */\n    var btn_refresh_internal = gradioApp().getElementById(tabname + \"_\" + extra_networks_tabname + \"_extra_refresh_internal\");\n    btn_refresh_internal.dispatchEvent(new Event(\"click\"));\n}\n\nvar globalPopup = null;\nvar globalPopupInner = null;\n\nfunction closePopup() {\n    if (!globalPopup) return;\n    globalPopup.style.display = \"none\";\n}\n\nfunction popup(contents) {\n    if (!globalPopup) {\n        globalPopup = document.createElement('div');\n        globalPopup.classList.add('global-popup');\n\n        var close = document.createElement('div');\n        close.classList.add('global-popup-close');\n        close.addEventListener(\"click\", closePopup);\n        close.title = \"Close\";\n        globalPopup.appendChild(close);\n\n        globalPopupInner = document.createElement('div');\n        globalPopupInner.classList.add('global-popup-inner');\n        globalPopup.appendChild(globalPopupInner);\n\n        gradioApp().querySelector('.main').appendChild(globalPopup);\n    }\n\n    globalPopupInner.innerHTML = '';\n    globalPopupInner.appendChild(contents);\n\n    globalPopup.style.display = \"flex\";\n}\n\nvar storedPopupIds = {};\nfunction popupId(id) {\n    if (!storedPopupIds[id]) {\n        storedPopupIds[id] = gradioApp().getElementById(id);\n    }\n\n    popup(storedPopupIds[id]);\n}\n\nfunction extraNetworksFlattenMetadata(obj) {\n    const result = {};\n\n    // Convert any stringified JSON objects to actual objects\n    for (const key of Object.keys(obj)) {\n        if (typeof obj[key] === 'string') {\n            try {\n                const parsed = JSON.parse(obj[key]);\n                if (parsed && typeof parsed === 'object') {\n                    obj[key] = parsed;\n                }\n            } catch (error) {\n                continue;\n            }\n        }\n    }\n\n    // Flatten the object\n    for (const key of Object.keys(obj)) {\n        if (typeof obj[key] === 'object' && 
obj[key] !== null) {\n            const nested = extraNetworksFlattenMetadata(obj[key]);\n            for (const nestedKey of Object.keys(nested)) {\n                result[`${key}/${nestedKey}`] = nested[nestedKey];\n            }\n        } else {\n            result[key] = obj[key];\n        }\n    }\n\n    // Special case for handling modelspec keys\n    for (const key of Object.keys(result)) {\n        if (key.startsWith(\"modelspec.\")) {\n            result[key.replaceAll(\".\", \"/\")] = result[key];\n            delete result[key];\n        }\n    }\n\n    // Add empty keys to designate hierarchy\n    for (const key of Object.keys(result)) {\n        const parts = key.split(\"/\");\n        for (let i = 1; i < parts.length; i++) {\n            const parent = parts.slice(0, i).join(\"/\");\n            if (!result[parent]) {\n                result[parent] = \"\";\n            }\n        }\n    }\n\n    return result;\n}\n\nfunction extraNetworksShowMetadata(text) {\n    try {\n        let parsed = JSON.parse(text);\n        if (parsed && typeof parsed === 'object') {\n            parsed = extraNetworksFlattenMetadata(parsed);\n            const table = createVisualizationTable(parsed, 0);\n            popup(table);\n            return;\n        }\n    } catch (error) {\n        console.error(error);\n    }\n\n    var elem = document.createElement('pre');\n    elem.classList.add('popup-metadata');\n    elem.textContent = text;\n\n    popup(elem);\n    return;\n}\n\nfunction requestGet(url, data, handler, errorHandler) {\n    var xhr = new XMLHttpRequest();\n    var args = Object.keys(data).map(function(k) {\n        return encodeURIComponent(k) + '=' + encodeURIComponent(data[k]);\n    }).join('&');\n    xhr.open(\"GET\", url + \"?\" + args, true);\n\n    xhr.onreadystatechange = function() {\n        if (xhr.readyState === 4) {\n            if (xhr.status === 200) {\n                try {\n                    var js = JSON.parse(xhr.responseText);\n        
            handler(js);\n                } catch (error) {\n                    console.error(error);\n                    errorHandler();\n                }\n            } else {\n                errorHandler();\n            }\n        }\n    };\n    var js = JSON.stringify(data);\n    xhr.send(js);\n}\n\nfunction extraNetworksCopyCardPath(event) {\n    navigator.clipboard.writeText(event.target.getAttribute(\"data-clipboard-text\"));\n    event.stopPropagation();\n}\n\nfunction extraNetworksRequestMetadata(event, extraPage) {\n    var showError = function() {\n        extraNetworksShowMetadata(\"there was an error getting metadata\");\n    };\n\n    var cardName = event.target.parentElement.parentElement.getAttribute(\"data-name\");\n\n    requestGet(\"./sd_extra_networks/metadata\", {page: extraPage, item: cardName}, function(data) {\n        if (data && data.metadata) {\n            extraNetworksShowMetadata(data.metadata);\n        } else {\n            showError();\n        }\n    }, showError);\n\n    event.stopPropagation();\n}\n\nvar extraPageUserMetadataEditors = {};\n\nfunction extraNetworksEditUserMetadata(event, tabname, extraPage) {\n    var id = tabname + '_' + extraPage + '_edit_user_metadata';\n\n    var editor = extraPageUserMetadataEditors[id];\n    if (!editor) {\n        editor = {};\n        editor.page = gradioApp().getElementById(id);\n        editor.nameTextarea = gradioApp().querySelector(\"#\" + id + \"_name\" + ' textarea');\n        editor.button = gradioApp().querySelector(\"#\" + id + \"_button\");\n        extraPageUserMetadataEditors[id] = editor;\n    }\n\n    var cardName = event.target.parentElement.parentElement.getAttribute(\"data-name\");\n    editor.nameTextarea.value = cardName;\n    updateInput(editor.nameTextarea);\n\n    editor.button.click();\n\n    popup(editor.page);\n\n    event.stopPropagation();\n}\n\nfunction extraNetworksRefreshSingleCard(page, tabname, name) {\n    
requestGet(\"./sd_extra_networks/get-single-card\", {page: page, tabname: tabname, name: name}, function(data) {\n        if (data && data.html) {\n            var card = gradioApp().querySelector(`#${tabname}_${page.replace(\" \", \"_\")}_cards > .card[data-name=\"${name}\"]`);\n\n            var newDiv = document.createElement('DIV');\n            newDiv.innerHTML = data.html;\n            var newCard = newDiv.firstElementChild;\n\n            newCard.style.display = '';\n            card.parentElement.insertBefore(newCard, card);\n            card.parentElement.removeChild(card);\n        }\n    });\n}\n\nwindow.addEventListener(\"keydown\", function(event) {\n    if (event.key == \"Escape\") {\n        closePopup();\n    }\n});\n\n/**\n * Setup custom loading for this script.\n * We need to wait for all of our HTML to be generated in the extra networks tabs\n * before we can actually run the `setupExtraNetworks` function.\n * The `onUiLoaded` function actually runs before all of our extra network tabs are\n * finished generating. 
Thus we needed this new method.\n *\n */\n\nvar uiAfterScriptsCallbacks = [];\nvar uiAfterScriptsTimeout = null;\nvar executedAfterScripts = false;\n\nfunction scheduleAfterScriptsCallbacks() {\n    clearTimeout(uiAfterScriptsTimeout);\n    uiAfterScriptsTimeout = setTimeout(function() {\n        executeCallbacks(uiAfterScriptsCallbacks);\n    }, 200);\n}\n\nonUiLoaded(function() {\n    var mutationObserver = new MutationObserver(function(m) {\n        let existingSearchfields = gradioApp().querySelectorAll(\"[id$='_extra_search']\").length;\n        let neededSearchfields = gradioApp().querySelectorAll(\"[id$='_extra_tabs'] > .tab-nav > button\").length - 2;\n\n        if (!executedAfterScripts && existingSearchfields >= neededSearchfields) {\n            mutationObserver.disconnect();\n            executedAfterScripts = true;\n            scheduleAfterScriptsCallbacks();\n        }\n    });\n    mutationObserver.observe(gradioApp(), {childList: true, subtree: true});\n});\n\nuiAfterScriptsCallbacks.push(setupExtraNetworks);\n"
  },
  {
    "path": "javascript/generationParams.js",
    "content": "// attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes\n\nlet txt2img_gallery, img2img_gallery, modal = undefined;\nonAfterUiUpdate(function() {\n    if (!txt2img_gallery) {\n        txt2img_gallery = attachGalleryListeners(\"txt2img\");\n    }\n    if (!img2img_gallery) {\n        img2img_gallery = attachGalleryListeners(\"img2img\");\n    }\n    if (!modal) {\n        modal = gradioApp().getElementById('lightboxModal');\n        modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']});\n    }\n});\n\nlet modalObserver = new MutationObserver(function(mutations) {\n    mutations.forEach(function(mutationRecord) {\n        let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText;\n        if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) {\n            gradioApp().getElementById(selectedTab + \"_generation_info_button\")?.click();\n        }\n    });\n});\n\nfunction attachGalleryListeners(tab_name) {\n    var gallery = gradioApp().querySelector('#' + tab_name + '_gallery');\n    gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + \"_generation_info_button\").click());\n    gallery?.addEventListener('keydown', (e) => {\n        if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow\n            gradioApp().getElementById(tab_name + \"_generation_info_button\").click();\n        }\n    });\n    return gallery;\n}\n"
  },
  {
    "path": "javascript/hints.js",
    "content": "// mouseover tooltips for various UI elements\n\nvar titles = {\n    \"Sampling steps\": \"How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results\",\n    \"Sampling method\": \"Which algorithm to use to produce the image\",\n    \"GFPGAN\": \"Restore low quality faces using GFPGAN neural network\",\n    \"Euler a\": \"Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help\",\n    \"DDIM\": \"Denoising Diffusion Implicit Models - best at inpainting\",\n    \"UniPC\": \"Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models\",\n    \"DPM adaptive\": \"Ignores step count - uses a number of steps determined by the CFG and resolution\",\n\n    \"\\u{1F4D0}\": \"Auto detect size from img2img\",\n    \"Batch count\": \"How many batches of images to create (has no impact on generation performance or VRAM usage)\",\n    \"Batch size\": \"How many image to create in a single batch (increases generation performance at cost of higher VRAM usage)\",\n    \"CFG Scale\": \"Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results\",\n    \"Seed\": \"A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result\",\n    \"\\u{1f3b2}\\ufe0f\": \"Set seed to -1, which will cause a new random number to be used every time\",\n    \"\\u267b\\ufe0f\": \"Reuse seed from last generation, mostly useful if it was randomized\",\n    \"\\u2199\\ufe0f\": \"Read generation parameters from prompt or last generation if prompt is empty into user interface.\",\n    \"\\u{1f4c2}\": \"Open images output directory\",\n    \"\\u{1f4be}\": \"Save style\",\n    \"\\u{1f5d1}\\ufe0f\": \"Clear prompt\",\n    \"\\u{1f4cb}\": \"Apply selected 
styles to current prompt\",\n    \"\\u{1f4d2}\": \"Paste available values into the field\",\n    \"\\u{1f3b4}\": \"Show/hide extra networks\",\n    \"\\u{1f300}\": \"Restore progress\",\n\n    \"Inpaint a part of image\": \"Draw a mask over an image, and the script will regenerate the masked area with content according to prompt\",\n    \"SD upscale\": \"Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back\",\n\n    \"Just resize\": \"Resize image to target resolution. Unless height and width match, you will get incorrect aspect ratio.\",\n    \"Crop and resize\": \"Resize the image so that entirety of target resolution is filled with the image. Crop parts that stick out.\",\n    \"Resize and fill\": \"Resize the image so that entirety of image is inside target resolution. Fill empty space with image's colors.\",\n\n    \"Mask blur\": \"How much to blur the mask before processing, in pixels.\",\n    \"Masked content\": \"What to put inside the masked area before processing it with Stable Diffusion.\",\n    \"fill\": \"fill it with colors of the image\",\n    \"original\": \"keep whatever was there originally\",\n    \"latent noise\": \"fill it with latent space noise\",\n    \"latent nothing\": \"fill it with latent space zeroes\",\n    \"Inpaint at full resolution\": \"Upscale masked region to target resolution, do inpainting, downscale back and paste into original image\",\n\n    \"Denoising strength\": \"Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. 
With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.\",\n\n    \"Skip\": \"Stop processing current image and continue processing.\",\n    \"Interrupt\": \"Stop processing images and return any results accumulated so far.\",\n    \"Save\": \"Write image to a directory (default - log/images) and generation parameters into csv file.\",\n\n    \"X values\": \"Separate values for X axis using commas.\",\n    \"Y values\": \"Separate values for Y axis using commas.\",\n\n    \"None\": \"Do not do anything special\",\n    \"Prompt matrix\": \"Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)\",\n    \"X/Y/Z plot\": \"Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows\",\n    \"Custom code\": \"Run Python code. Advanced user only. Must run program with --allow-code for this to work\",\n\n    \"Prompt S/R\": \"Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others\",\n    \"Prompt order\": \"Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order\",\n\n    \"Tiling\": \"Produce an image that can be tiled.\",\n    \"Tile overlap\": \"For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.\",\n\n    \"Variation seed\": \"Seed of a different picture to be mixed into the generation.\",\n    \"Variation strength\": \"How strong of a variation to produce. At 0, there will be no effect. 
At 1, you will get the complete picture with variation seed (except for ancestral samplers, where you will just get something).\",\n    \"Resize seed from height\": \"Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution\",\n    \"Resize seed from width\": \"Make an attempt to produce a picture similar to what would have been produced with same seed at specified resolution\",\n\n    \"Interrogate\": \"Reconstruct prompt from existing image and put it into the prompt field.\",\n\n    \"Images filename pattern\": \"Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.\",\n    \"Directory name pattern\": \"Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.\",\n    \"Max prompt words\": \"Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle\",\n\n    \"Loopback\": \"Performs img2img processing multiple times. Output images are used as input for the next loop.\",\n    \"Loops\": \"How many times to process an image. Each output is used as the input of the next loop. If set to 1, behavior will be as if this script were not used.\",\n    \"Final denoising strength\": \"The denoising strength for the final loop of each image in the batch.\",\n    \"Denoising strength curve\": \"The denoising curve controls the rate of denoising strength change each loop. Aggressive: Most of the change will happen towards the start of the loops. Linear: Change will be constant through all loops. 
Lazy: Most of the change will happen towards the end of the loops.\",\n\n    \"Style 1\": \"Style to apply; styles have components for both positive and negative prompts and apply to both\",\n    \"Style 2\": \"Style to apply; styles have components for both positive and negative prompts and apply to both\",\n    \"Apply style\": \"Insert selected styles into prompt fields\",\n    \"Create style\": \"Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.\",\n\n    \"Checkpoint name\": \"Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.\",\n    \"Inpainting conditioning mask strength\": \"Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.\",\n\n    \"Eta noise seed delta\": \"If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.\",\n\n    \"Filename word regex\": \"This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. 
Leave empty to keep filename text as it is.\",\n    \"Filename join string\": \"This string will be used to join split words into a single line if the option above is enabled.\",\n\n    \"Quicksettings list\": \"List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.\",\n\n    \"Weighted sum\": \"Result = A * (1 - M) + B * M\",\n    \"Add difference\": \"Result = A + (B - C) * M\",\n    \"No interpolation\": \"Result = A\",\n\n    \"Initialization text\": \"If the number of tokens is more than the number of vectors, some may be skipped.\\nLeave the textbox empty to start with zeroed out vectors\",\n    \"Learning rate\": \"How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\\n\\nYou can set a single numeric value, or multiple learning rates using the syntax:\\n\\n   rate_1:max_steps_1, rate_2:max_steps_2, ...\\n\\nEG:   0.005:100, 1e-3:1000, 1e-5\\n\\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.\",\n\n    \"Clip skip\": \"Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.\",\n\n    \"Approx NN\": \"Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.\",\n    \"Approx cheap\": \"Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.\",\n\n    \"Hires. 
fix\": \"Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition\",\n    \"Hires steps\": \"Number of sampling steps for upscaled picture. If 0, uses same as for original.\",\n    \"Upscale by\": \"Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.\",\n    \"Resize width to\": \"Resizes image to this width. If 0, width is inferred from either of two nearby sliders.\",\n    \"Resize height to\": \"Resizes image to this height. If 0, height is inferred from either of two nearby sliders.\",\n    \"Discard weights with matching name\": \"Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.\",\n    \"Extra networks tab order\": \"Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order listed.\",\n    \"Negative Guidance minimum sigma\": \"Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction.\"\n};\n\nfunction updateTooltip(element) {\n    if (element.title) return; // already has a title\n\n    let text = element.textContent;\n    let tooltip = localization[titles[text]] || titles[text];\n\n    if (!tooltip) {\n        let value = element.value;\n        if (value) tooltip = localization[titles[value]] || titles[value];\n    }\n\n    if (!tooltip) {\n        // Gradio dropdown options have `data-value`.\n        let dataValue = element.dataset.value;\n        if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue];\n    }\n\n    if (!tooltip) {\n        for (const c of element.classList) {\n            if (c in titles) {\n                tooltip = 
localization[titles[c]] || titles[c];\n                break;\n            }\n        }\n    }\n\n    if (tooltip) {\n        element.title = tooltip;\n    }\n}\n\n// Nodes to check for adding tooltips.\nconst tooltipCheckNodes = new Set();\n// Timer for debouncing tooltip check.\nlet tooltipCheckTimer = null;\n\nfunction processTooltipCheckNodes() {\n    for (const node of tooltipCheckNodes) {\n        updateTooltip(node);\n    }\n    tooltipCheckNodes.clear();\n}\n\nonUiUpdate(function(mutationRecords) {\n    for (const record of mutationRecords) {\n        if (record.type === \"childList\" && record.target.classList.contains(\"options\")) {\n            // This smells like a Gradio dropdown menu having changed,\n            // so let's enqueue an update for the input element that shows the current value.\n            let wrap = record.target.parentNode;\n            let input = wrap?.querySelector(\"input\");\n            if (input) {\n                input.title = \"\"; // So we'll even have a chance to update it.\n                tooltipCheckNodes.add(input);\n            }\n        }\n        for (const node of record.addedNodes) {\n            if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains(\"hide\")) {\n                if (!node.title) {\n                    if (\n                        node.tagName === \"SPAN\" ||\n                        node.tagName === \"BUTTON\" ||\n                        node.tagName === \"P\" ||\n                        node.tagName === \"INPUT\" ||\n                        (node.tagName === \"LI\" && node.classList.contains(\"item\")) // Gradio dropdown item\n                    ) {\n                        tooltipCheckNodes.add(node);\n                    }\n                }\n                node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n));\n            }\n        }\n    }\n    if (tooltipCheckNodes.size) {\n        clearTimeout(tooltipCheckTimer);\n        tooltipCheckTimer = 
setTimeout(processTooltipCheckNodes, 1000);\n    }\n});\n\nonUiLoaded(function() {\n    for (var comp of window.gradio_config.components) {\n        if (comp.props.webui_tooltip && comp.props.elem_id) {\n            var elem = gradioApp().getElementById(comp.props.elem_id);\n            if (elem) {\n                elem.title = comp.props.webui_tooltip;\n            }\n        }\n    }\n});\n"
  },
  {
    "path": "javascript/hires_fix.js",
    "content": "\nfunction onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) {\n    function setInactive(elem, inactive) {\n        elem.classList.toggle('inactive', !!inactive);\n    }\n\n    var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale');\n    var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x');\n    var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y');\n\n    gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? \"none\" : \"\";\n\n    setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0);\n    setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0);\n    setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0);\n\n    return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y];\n}\n"
  },
  {
    "path": "javascript/imageMaskFix.js",
    "content": "/**\n * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668\n * @see https://github.com/gradio-app/gradio/issues/1721\n */\nfunction imageMaskResize() {\n    const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');\n    if (!canvases.length) {\n        window.removeEventListener('resize', imageMaskResize);\n        return;\n    }\n\n    const wrapper = canvases[0].closest('.touch-none');\n    const previewImage = wrapper.previousElementSibling;\n\n    if (!previewImage.complete) {\n        previewImage.addEventListener('load', imageMaskResize);\n        return;\n    }\n\n    const w = previewImage.width;\n    const h = previewImage.height;\n    const nw = previewImage.naturalWidth;\n    const nh = previewImage.naturalHeight;\n    const portrait = nh > nw;\n\n    const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw);\n    const wH = Math.min(h, portrait ? h / nh * nh : w / nw * nh);\n\n    wrapper.style.width = `${wW}px`;\n    wrapper.style.height = `${wH}px`;\n    wrapper.style.left = `0px`;\n    wrapper.style.top = `0px`;\n\n    canvases.forEach(c => {\n        c.style.width = c.style.height = '';\n        c.style.maxWidth = '100%';\n        c.style.maxHeight = '100%';\n        c.style.objectFit = 'contain';\n    });\n}\n\nonAfterUiUpdate(imageMaskResize);\nwindow.addEventListener('resize', imageMaskResize);\n"
  },
  {
    "path": "javascript/imageviewer.js",
    "content": "// A full size 'lightbox' preview modal shown when left clicking on gallery previews\nfunction closeModal() {\n    gradioApp().getElementById(\"lightboxModal\").style.display = \"none\";\n}\n\nfunction showModal(event) {\n    const source = event.target || event.srcElement;\n    const modalImage = gradioApp().getElementById(\"modalImage\");\n    const modalToggleLivePreviewBtn = gradioApp().getElementById(\"modal_toggle_live_preview\");\n    modalToggleLivePreviewBtn.innerHTML = opts.js_live_preview_in_modal_lightbox ? \"&#x1F5C7;\" : \"&#x1F5C6;\";\n    const lb = gradioApp().getElementById(\"lightboxModal\");\n    modalImage.src = source.src;\n    if (modalImage.style.display === 'none') {\n        lb.style.setProperty('background-image', 'url(' + source.src + ')');\n    }\n    lb.style.display = \"flex\";\n    lb.focus();\n\n    const tabTxt2Img = gradioApp().getElementById(\"tab_txt2img\");\n    const tabImg2Img = gradioApp().getElementById(\"tab_img2img\");\n    // show the save button in modal only on txt2img or img2img tabs\n    if (tabTxt2Img.style.display != \"none\" || tabImg2Img.style.display != \"none\") {\n        gradioApp().getElementById(\"modal_save\").style.display = \"inline\";\n    } else {\n        gradioApp().getElementById(\"modal_save\").style.display = \"none\";\n    }\n    event.stopPropagation();\n}\n\nfunction negmod(n, m) {\n    return ((n % m) + m) % m;\n}\n\nfunction updateOnBackgroundChange() {\n    const modalImage = gradioApp().getElementById(\"modalImage\");\n    if (modalImage && modalImage.offsetParent) {\n        let currentButton = selected_gallery_button();\n        let preview = gradioApp().querySelectorAll('.livePreview > img');\n        if (opts.js_live_preview_in_modal_lightbox && preview.length > 0) {\n            // show preview image if available\n            modalImage.src = preview[preview.length - 1].src;\n        } else if (currentButton?.children?.length > 0 && modalImage.src != 
currentButton.children[0].src) {\n            modalImage.src = currentButton.children[0].src;\n            if (modalImage.style.display === 'none') {\n                const modal = gradioApp().getElementById(\"lightboxModal\");\n                modal.style.setProperty('background-image', `url(${modalImage.src})`);\n            }\n        }\n    }\n}\n\nfunction modalImageSwitch(offset) {\n    var galleryButtons = all_gallery_buttons();\n\n    if (galleryButtons.length > 1) {\n        var result = selected_gallery_index();\n\n        if (result != -1) {\n            var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)];\n            nextButton.click();\n            const modalImage = gradioApp().getElementById(\"modalImage\");\n            const modal = gradioApp().getElementById(\"lightboxModal\");\n            modalImage.src = nextButton.children[0].src;\n            if (modalImage.style.display === 'none') {\n                modal.style.setProperty('background-image', `url(${modalImage.src})`);\n            }\n            setTimeout(function() {\n                modal.focus();\n            }, 10);\n        }\n    }\n}\n\nfunction saveImage() {\n    const tabTxt2Img = gradioApp().getElementById(\"tab_txt2img\");\n    const tabImg2Img = gradioApp().getElementById(\"tab_img2img\");\n    const saveTxt2Img = \"save_txt2img\";\n    const saveImg2Img = \"save_img2img\";\n    if (tabTxt2Img.style.display != \"none\") {\n        gradioApp().getElementById(saveTxt2Img).click();\n    } else if (tabImg2Img.style.display != \"none\") {\n        gradioApp().getElementById(saveImg2Img).click();\n    } else {\n        console.error(\"missing implementation for saving modal of this type\");\n    }\n}\n\nfunction modalSaveImage(event) {\n    saveImage();\n    event.stopPropagation();\n}\n\nfunction modalNextImage(event) {\n    modalImageSwitch(1);\n    event.stopPropagation();\n}\n\nfunction modalPrevImage(event) {\n    modalImageSwitch(-1);\n    
event.stopPropagation();\n}\n\nfunction modalKeyHandler(event) {\n    switch (event.key) {\n    case \"s\":\n        saveImage();\n        break;\n    case \"ArrowLeft\":\n        modalPrevImage(event);\n        break;\n    case \"ArrowRight\":\n        modalNextImage(event);\n        break;\n    case \"Escape\":\n        closeModal();\n        break;\n    }\n}\n\nfunction setupImageForLightbox(e) {\n    if (e.dataset.modded) {\n        return;\n    }\n\n    e.dataset.modded = true;\n    e.style.cursor = 'pointer';\n    e.style.userSelect = 'none';\n\n    e.addEventListener('mousedown', function(evt) {\n        if (evt.button == 1) {\n            open(evt.target.src);\n            evt.preventDefault();\n            return;\n        }\n    }, true);\n\n    e.addEventListener('click', function(evt) {\n        if (!opts.js_modal_lightbox || evt.button != 0) return;\n\n        modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed);\n        evt.preventDefault();\n        showModal(evt);\n    }, true);\n\n}\n\nfunction modalZoomSet(modalImage, enable) {\n    if (modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);\n}\n\nfunction modalZoomToggle(event) {\n    var modalImage = gradioApp().getElementById(\"modalImage\");\n    modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'));\n    event.stopPropagation();\n}\n\nfunction modalLivePreviewToggle(event) {\n    const modalToggleLivePreview = gradioApp().getElementById(\"modal_toggle_live_preview\");\n    opts.js_live_preview_in_modal_lightbox = !opts.js_live_preview_in_modal_lightbox;\n    modalToggleLivePreview.innerHTML = opts.js_live_preview_in_modal_lightbox ? 
\"&#x1F5C7;\" : \"&#x1F5C6;\";\n    event.stopPropagation();\n}\n\nfunction modalTileImageToggle(event) {\n    const modalImage = gradioApp().getElementById(\"modalImage\");\n    const modal = gradioApp().getElementById(\"lightboxModal\");\n    const isTiling = modalImage.style.display === 'none';\n    if (isTiling) {\n        modalImage.style.display = 'block';\n        modal.style.setProperty('background-image', 'none');\n    } else {\n        modalImage.style.display = 'none';\n        modal.style.setProperty('background-image', `url(${modalImage.src})`);\n    }\n\n    event.stopPropagation();\n}\n\nonAfterUiUpdate(function() {\n    var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img');\n    if (fullImg_preview != null) {\n        fullImg_preview.forEach(setupImageForLightbox);\n    }\n    updateOnBackgroundChange();\n});\n\ndocument.addEventListener(\"DOMContentLoaded\", function() {\n    //const modalFragment = document.createDocumentFragment();\n    const modal = document.createElement('div');\n    modal.onclick = closeModal;\n    modal.id = \"lightboxModal\";\n    modal.tabIndex = 0;\n    modal.addEventListener('keydown', modalKeyHandler, true);\n\n    const modalControls = document.createElement('div');\n    modalControls.className = 'modalControls gradio-container';\n    modal.append(modalControls);\n\n    const modalZoom = document.createElement('span');\n    modalZoom.className = 'modalZoom cursor';\n    modalZoom.innerHTML = '&#10529;';\n    modalZoom.addEventListener('click', modalZoomToggle, true);\n    modalZoom.title = \"Toggle zoomed view\";\n    modalControls.appendChild(modalZoom);\n\n    const modalTileImage = document.createElement('span');\n    modalTileImage.className = 'modalTileImage cursor';\n    modalTileImage.innerHTML = '&#8862;';\n    modalTileImage.addEventListener('click', modalTileImageToggle, true);\n    modalTileImage.title = \"Preview tiling\";\n    modalControls.appendChild(modalTileImage);\n\n    
const modalSave = document.createElement(\"span\");\n    modalSave.className = \"modalSave cursor\";\n    modalSave.id = \"modal_save\";\n    modalSave.innerHTML = \"&#x1F5AB;\";\n    modalSave.addEventListener(\"click\", modalSaveImage, true);\n    modalSave.title = \"Save Image(s)\";\n    modalControls.appendChild(modalSave);\n\n    const modalToggleLivePreview = document.createElement('span');\n    modalToggleLivePreview.className = 'modalToggleLivePreview cursor';\n    modalToggleLivePreview.id = \"modal_toggle_live_preview\";\n    modalToggleLivePreview.innerHTML = \"&#x1F5C6;\";\n    modalToggleLivePreview.onclick = modalLivePreviewToggle;\n    modalToggleLivePreview.title = \"Toggle live preview\";\n    modalControls.appendChild(modalToggleLivePreview);\n\n    const modalClose = document.createElement('span');\n    modalClose.className = 'modalClose cursor';\n    modalClose.innerHTML = '&times;';\n    modalClose.onclick = closeModal;\n    modalClose.title = \"Close image viewer\";\n    modalControls.appendChild(modalClose);\n\n    const modalImage = document.createElement('img');\n    modalImage.id = 'modalImage';\n    modalImage.onclick = closeModal;\n    modalImage.tabIndex = 0;\n    modalImage.addEventListener('keydown', modalKeyHandler, true);\n    modal.appendChild(modalImage);\n\n    const modalPrev = document.createElement('a');\n    modalPrev.className = 'modalPrev';\n    modalPrev.innerHTML = '&#10094;';\n    modalPrev.tabIndex = 0;\n    modalPrev.addEventListener('click', modalPrevImage, true);\n    modalPrev.addEventListener('keydown', modalKeyHandler, true);\n    modal.appendChild(modalPrev);\n\n    const modalNext = document.createElement('a');\n    modalNext.className = 'modalNext';\n    modalNext.innerHTML = '&#10095;';\n    modalNext.tabIndex = 0;\n    modalNext.addEventListener('click', modalNextImage, true);\n    modalNext.addEventListener('keydown', modalKeyHandler, true);\n\n    modal.appendChild(modalNext);\n\n    try {\n        
gradioApp().appendChild(modal);\n    } catch (e) {\n        gradioApp().body.appendChild(modal);\n    }\n\n    document.body.appendChild(modal);\n\n});\n"
  },
  {
    "path": "javascript/imageviewerGamepad.js",
    "content": "let gamepads = [];\n\nwindow.addEventListener('gamepadconnected', (e) => {\n    const index = e.gamepad.index;\n    let isWaiting = false;\n    gamepads[index] = setInterval(async() => {\n        if (!opts.js_modal_lightbox_gamepad || isWaiting) return;\n        const gamepad = navigator.getGamepads()[index];\n        const xValue = gamepad.axes[0];\n        if (xValue <= -0.3) {\n            modalPrevImage(e);\n            isWaiting = true;\n        } else if (xValue >= 0.3) {\n            modalNextImage(e);\n            isWaiting = true;\n        }\n        if (isWaiting) {\n            await sleepUntil(() => {\n                const xValue = navigator.getGamepads()[index].axes[0];\n                if (xValue < 0.3 && xValue > -0.3) {\n                    return true;\n                }\n            }, opts.js_modal_lightbox_gamepad_repeat);\n            isWaiting = false;\n        }\n    }, 10);\n});\n\nwindow.addEventListener('gamepaddisconnected', (e) => {\n    clearInterval(gamepads[e.gamepad.index]);\n});\n\n/*\nPrimarily for vr controller type pointer devices.\nI use the wheel event because there's currently no way to do it properly with web xr.\n */\nlet isScrolling = false;\nwindow.addEventListener('wheel', (e) => {\n    if (!opts.js_modal_lightbox_gamepad || isScrolling) return;\n    isScrolling = true;\n\n    if (e.deltaX <= -0.6) {\n        modalPrevImage(e);\n    } else if (e.deltaX >= 0.6) {\n        modalNextImage(e);\n    }\n\n    setTimeout(() => {\n        isScrolling = false;\n    }, opts.js_modal_lightbox_gamepad_repeat);\n});\n\nfunction sleepUntil(f, timeout) {\n    return new Promise((resolve) => {\n        const timeStart = new Date();\n        const wait = setInterval(function() {\n            if (f() || new Date() - timeStart > timeout) {\n                clearInterval(wait);\n                resolve();\n            }\n        }, 20);\n    });\n}\n"
  },
  {
    "path": "javascript/inputAccordion.js",
    "content": "function inputAccordionChecked(id, checked) {\n    var accordion = gradioApp().getElementById(id);\n    accordion.visibleCheckbox.checked = checked;\n    accordion.onVisibleCheckboxChange();\n}\n\nfunction setupAccordion(accordion) {\n    var labelWrap = accordion.querySelector('.label-wrap');\n    var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + \"-checkbox input\");\n    var extra = gradioApp().querySelector('#' + accordion.id + \"-extra\");\n    var span = labelWrap.querySelector('span');\n    var linked = true;\n\n    var isOpen = function() {\n        return labelWrap.classList.contains('open');\n    };\n\n    var observerAccordionOpen = new MutationObserver(function(mutations) {\n        mutations.forEach(function(mutationRecord) {\n            accordion.classList.toggle('input-accordion-open', isOpen());\n\n            if (linked) {\n                accordion.visibleCheckbox.checked = isOpen();\n                accordion.onVisibleCheckboxChange();\n            }\n        });\n    });\n    observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});\n\n    if (extra) {\n        labelWrap.insertBefore(extra, labelWrap.lastElementChild);\n    }\n\n    accordion.onChecked = function(checked) {\n        if (isOpen() != checked) {\n            labelWrap.click();\n        }\n    };\n\n    var visibleCheckbox = document.createElement('INPUT');\n    visibleCheckbox.type = 'checkbox';\n    visibleCheckbox.checked = isOpen();\n    visibleCheckbox.id = accordion.id + \"-visible-checkbox\";\n    visibleCheckbox.className = gradioCheckbox.className + \" input-accordion-checkbox\";\n    span.insertBefore(visibleCheckbox, span.firstChild);\n\n    accordion.visibleCheckbox = visibleCheckbox;\n    accordion.onVisibleCheckboxChange = function() {\n        if (linked && isOpen() != visibleCheckbox.checked) {\n            labelWrap.click();\n        }\n\n        gradioCheckbox.checked = visibleCheckbox.checked;\n    
    updateInput(gradioCheckbox);\n    };\n\n    visibleCheckbox.addEventListener('click', function(event) {\n        linked = false;\n        event.stopPropagation();\n    });\n    visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange);\n}\n\nonUiLoaded(function() {\n    for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {\n        setupAccordion(accordion);\n    }\n});\n"
  },
  {
    "path": "javascript/localStorage.js",
    "content": "\nfunction localSet(k, v) {\n    try {\n        localStorage.setItem(k, v);\n    } catch (e) {\n        console.warn(`Failed to save ${k} to localStorage: ${e}`);\n    }\n}\n\nfunction localGet(k, def) {\n    try {\n        return localStorage.getItem(k);\n    } catch (e) {\n        console.warn(`Failed to load ${k} from localStorage: ${e}`);\n    }\n\n    return def;\n}\n\nfunction localRemove(k) {\n    try {\n        return localStorage.removeItem(k);\n    } catch (e) {\n        console.warn(`Failed to remove ${k} from localStorage: ${e}`);\n    }\n}\n"
  },
  {
    "path": "javascript/localization.js",
    "content": "\n// localization = {} -- the dict with translations is created by the backend\n\nvar ignore_ids_for_localization = {\n    setting_sd_hypernetwork: 'OPTION',\n    setting_sd_model_checkpoint: 'OPTION',\n    modelmerger_primary_model_name: 'OPTION',\n    modelmerger_secondary_model_name: 'OPTION',\n    modelmerger_tertiary_model_name: 'OPTION',\n    train_embedding: 'OPTION',\n    train_hypernetwork: 'OPTION',\n    txt2img_styles: 'OPTION',\n    img2img_styles: 'OPTION',\n    setting_random_artist_categories: 'OPTION',\n    setting_face_restoration_model: 'OPTION',\n    setting_realesrgan_enabled_models: 'OPTION',\n    extras_upscaler_1: 'OPTION',\n    extras_upscaler_2: 'OPTION',\n};\n\nvar re_num = /^[.\\d]+$/;\nvar re_emoji = /[\\p{Extended_Pictographic}\\u{1F3FB}-\\u{1F3FF}\\u{1F9B0}-\\u{1F9B3}]/u;\n\nvar original_lines = {};\nvar translated_lines = {};\n\nfunction hasLocalization() {\n    return window.localization && Object.keys(window.localization).length > 0;\n}\n\nfunction textNodesUnder(el) {\n    var n, a = [], walk = document.createTreeWalker(el, NodeFilter.SHOW_TEXT, null, false);\n    while ((n = walk.nextNode())) a.push(n);\n    return a;\n}\n\nfunction canBeTranslated(node, text) {\n    if (!text) return false;\n    if (!node.parentElement) return false;\n\n    var parentType = node.parentElement.nodeName;\n    if (parentType == 'SCRIPT' || parentType == 'STYLE' || parentType == 'TEXTAREA') return false;\n\n    if (parentType == 'OPTION' || parentType == 'SPAN') {\n        var pnode = node;\n        for (var level = 0; level < 4; level++) {\n            pnode = pnode.parentElement;\n            if (!pnode) break;\n\n            if (ignore_ids_for_localization[pnode.id] == parentType) return false;\n        }\n    }\n\n    if (re_num.test(text)) return false;\n    if (re_emoji.test(text)) return false;\n    return true;\n}\n\nfunction getTranslation(text) {\n    if (!text) return undefined;\n\n    if (translated_lines[text] === 
undefined) {\n        original_lines[text] = 1;\n    }\n\n    var tl = localization[text];\n    if (tl !== undefined) {\n        translated_lines[tl] = 1;\n    }\n\n    return tl;\n}\n\nfunction processTextNode(node) {\n    var text = node.textContent.trim();\n\n    if (!canBeTranslated(node, text)) return;\n\n    var tl = getTranslation(text);\n    if (tl !== undefined) {\n        node.textContent = tl;\n    }\n}\n\nfunction processNode(node) {\n    if (node.nodeType == 3) {\n        processTextNode(node);\n        return;\n    }\n\n    if (node.title) {\n        let tl = getTranslation(node.title);\n        if (tl !== undefined) {\n            node.title = tl;\n        }\n    }\n\n    if (node.placeholder) {\n        let tl = getTranslation(node.placeholder);\n        if (tl !== undefined) {\n            node.placeholder = tl;\n        }\n    }\n\n    textNodesUnder(node).forEach(function(node) {\n        processTextNode(node);\n    });\n}\n\nfunction localizeWholePage() {\n    processNode(gradioApp());\n\n    function elem(comp) {\n        var elem_id = comp.props.elem_id ? comp.props.elem_id : \"component-\" + comp.id;\n        return gradioApp().getElementById(elem_id);\n    }\n\n    for (var comp of window.gradio_config.components) {\n        if (comp.props.webui_tooltip) {\n            let e = elem(comp);\n\n            let tl = e ? getTranslation(e.title) : undefined;\n            if (tl !== undefined) {\n                e.title = tl;\n            }\n        }\n        if (comp.props.placeholder) {\n            let e = elem(comp);\n            let textbox = e ? e.querySelector('[placeholder]') : null;\n\n            let tl = textbox ? 
getTranslation(textbox.placeholder) : undefined;\n            if (tl !== undefined) {\n                textbox.placeholder = tl;\n            }\n        }\n    }\n}\n\nfunction dumpTranslations() {\n    if (!hasLocalization()) {\n        // If we don't have any localization,\n        // we will not have traversed the app to find\n        // original_lines, so do that now.\n        localizeWholePage();\n    }\n    var dumped = {};\n    if (localization.rtl) {\n        dumped.rtl = true;\n    }\n\n    for (const text in original_lines) {\n        if (dumped[text] !== undefined) continue;\n        dumped[text] = localization[text] || text;\n    }\n\n    return dumped;\n}\n\nfunction download_localization() {\n    var text = JSON.stringify(dumpTranslations(), null, 4);\n\n    var element = document.createElement('a');\n    element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));\n    element.setAttribute('download', \"localization.json\");\n    element.style.display = 'none';\n    document.body.appendChild(element);\n\n    element.click();\n\n    document.body.removeChild(element);\n}\n\ndocument.addEventListener(\"DOMContentLoaded\", function() {\n    if (!hasLocalization()) {\n        return;\n    }\n\n    onUiUpdate(function(m) {\n        m.forEach(function(mutation) {\n            mutation.addedNodes.forEach(function(node) {\n                processNode(node);\n            });\n        });\n    });\n\n    localizeWholePage();\n\n    if (localization.rtl) { // if the language is from right to left,\n        (new MutationObserver((mutations, observer) => { // wait for the style to load\n            mutations.forEach(mutation => {\n                mutation.addedNodes.forEach(node => {\n                    if (node.tagName === 'STYLE') {\n                        observer.disconnect();\n\n                        for (const x of node.sheet.rules) { // find all rtl media rules\n                            if (Array.from(x.media || 
[]).includes('rtl')) {\n                                x.media.appendMedium('all'); // enable them\n                            }\n                        }\n                    }\n                });\n            });\n        })).observe(gradioApp(), {childList: true});\n    }\n});\n"
  },
  {
    "path": "javascript/notification.js",
    "content": "// Monitors the gallery and sends a browser notification when the leading image is new.\n\nlet lastHeadImg = null;\n\nlet notificationButton = null;\n\nonAfterUiUpdate(function() {\n    if (notificationButton == null) {\n        notificationButton = gradioApp().getElementById('request_notifications');\n\n        if (notificationButton != null) {\n            notificationButton.addEventListener('click', () => {\n                void Notification.requestPermission();\n            }, true);\n        }\n    }\n\n    const galleryPreviews = gradioApp().querySelectorAll('div[id^=\"tab_\"] div[id$=\"_results\"] .thumbnail-item > img');\n\n    if (galleryPreviews == null) return;\n\n    const headImg = galleryPreviews[0]?.src;\n\n    if (headImg == null || headImg == lastHeadImg) return;\n\n    lastHeadImg = headImg;\n\n    // play notification sound if available\n    const notificationAudio = gradioApp().querySelector('#audio_notification audio');\n    if (notificationAudio) {\n        notificationAudio.volume = opts.notification_volume / 100.0 || 1.0;\n        notificationAudio.play();\n    }\n\n    if (document.hasFocus()) return;\n\n    // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated.\n    const imgs = new Set(Array.from(galleryPreviews).map(img => img.src));\n\n    const notification = new Notification(\n        'Stable Diffusion',\n        {\n            body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,\n            icon: headImg,\n            image: headImg,\n        }\n    );\n\n    notification.onclick = function(_) {\n        parent.focus();\n        this.close();\n    };\n});\n"
  },
  {
    "path": "javascript/profilerVisualization.js",
    "content": "\nfunction createRow(table, cellName, items) {\n    var tr = document.createElement('tr');\n    var res = [];\n\n    items.forEach(function(x, i) {\n        if (x === undefined) {\n            res.push(null);\n            return;\n        }\n\n        var td = document.createElement(cellName);\n        td.textContent = x;\n        tr.appendChild(td);\n        res.push(td);\n\n        var colspan = 1;\n        for (var n = i + 1; n < items.length; n++) {\n            if (items[n] !== undefined) {\n                break;\n            }\n\n            colspan += 1;\n        }\n\n        if (colspan > 1) {\n            td.colSpan = colspan;\n        }\n    });\n\n    table.appendChild(tr);\n\n    return res;\n}\n\nfunction createVisualizationTable(data, cutoff = 0, sort = \"\") {\n    var table = document.createElement('table');\n    table.className = 'popup-table';\n\n    var keys = Object.keys(data);\n    if (sort === \"number\") {\n        keys = keys.sort(function(a, b) {\n            return data[b] - data[a];\n        });\n    } else {\n        keys = keys.sort();\n    }\n    var items = keys.map(function(x) {\n        return {key: x, parts: x.split('/'), value: data[x]};\n    });\n    var maxLength = items.reduce(function(a, b) {\n        return Math.max(a, b.parts.length);\n    }, 0);\n\n    var cols = createRow(\n        table,\n        'th',\n        [\n            cutoff === 0 ? 'key' : 'record',\n            cutoff === 0 ? 
'value' : 'seconds'\n        ]\n    );\n    cols[0].colSpan = maxLength;\n\n    function arraysEqual(a, b) {\n        return !(a < b || b < a);\n    }\n\n    var addLevel = function(level, parent, hide) {\n        var matching = items.filter(function(x) {\n            return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent);\n        });\n        if (sort === \"number\") {\n            matching = matching.sort(function(a, b) {\n                return b.value - a.value;\n            });\n        } else {\n            matching = matching.sort();\n        }\n        var othersTime = 0;\n        var othersList = [];\n        var othersRows = [];\n        var childrenRows = [];\n        matching.forEach(function(x) {\n            var visible = (cutoff === 0 && !hide) || (x.value >= cutoff && !hide);\n\n            var cells = [];\n            for (var i = 0; i < maxLength; i++) {\n                cells.push(x.parts[i]);\n            }\n            cells.push(cutoff === 0 ? 
x.value : x.value.toFixed(3));\n            var cols = createRow(table, 'td', cells);\n            for (i = 0; i < level; i++) {\n                cols[i].className = 'muted';\n            }\n\n            var tr = cols[0].parentNode;\n            if (!visible) {\n                tr.classList.add(\"hidden\");\n            }\n\n            if (cutoff === 0 || x.value >= cutoff) {\n                childrenRows.push(tr);\n            } else {\n                othersTime += x.value;\n                othersList.push(x.parts[level]);\n                othersRows.push(tr);\n            }\n\n            var children = addLevel(level + 1, parent.concat([x.parts[level]]), true);\n            if (children.length > 0) {\n                var cell = cols[level];\n                var onclick = function() {\n                    cell.classList.remove(\"link\");\n                    cell.removeEventListener(\"click\", onclick);\n                    children.forEach(function(x) {\n                        x.classList.remove(\"hidden\");\n                    });\n                };\n                cell.classList.add(\"link\");\n                cell.addEventListener(\"click\", onclick);\n            }\n        });\n\n        if (othersTime > 0) {\n            var cells = [];\n            for (var i = 0; i < maxLength; i++) {\n                cells.push(parent[i]);\n            }\n            cells.push(othersTime.toFixed(3));\n            cells[level] = 'others';\n            var cols = createRow(table, 'td', cells);\n            for (i = 0; i < level; i++) {\n                cols[i].className = 'muted';\n            }\n\n            var cell = cols[level];\n            var tr = cell.parentNode;\n            var onclick = function() {\n                tr.classList.add(\"hidden\");\n                cell.classList.remove(\"link\");\n                cell.removeEventListener(\"click\", onclick);\n                othersRows.forEach(function(x) {\n                    
x.classList.remove(\"hidden\");\n                });\n            };\n\n            cell.title = othersList.join(\", \");\n            cell.classList.add(\"link\");\n            cell.addEventListener(\"click\", onclick);\n\n            if (hide) {\n                tr.classList.add(\"hidden\");\n            }\n\n            childrenRows.push(tr);\n        }\n\n        return childrenRows;\n    };\n\n    addLevel(0, []);\n\n    return table;\n}\n\nfunction showProfile(path, cutoff = 0.05) {\n    requestGet(path, {}, function(data) {\n        data.records['total'] = data.total;\n        const table = createVisualizationTable(data.records, cutoff, \"number\");\n        popup(table);\n    });\n}\n\n"
  },
  {
    "path": "javascript/progressbar.js",
    "content": "// code related to showing and updating progressbar shown as the image is being made\n\nfunction rememberGallerySelection() {\n\n}\n\nfunction getGallerySelectedIndex() {\n\n}\n\nfunction request(url, data, handler, errorHandler) {\n    var xhr = new XMLHttpRequest();\n    xhr.open(\"POST\", url, true);\n    xhr.setRequestHeader(\"Content-Type\", \"application/json\");\n    xhr.onreadystatechange = function() {\n        if (xhr.readyState === 4) {\n            if (xhr.status === 200) {\n                try {\n                    var js = JSON.parse(xhr.responseText);\n                    handler(js);\n                } catch (error) {\n                    console.error(error);\n                    errorHandler();\n                }\n            } else {\n                errorHandler();\n            }\n        }\n    };\n    var js = JSON.stringify(data);\n    xhr.send(js);\n}\n\nfunction pad2(x) {\n    return x < 10 ? '0' + x : x;\n}\n\nfunction formatTime(secs) {\n    if (secs > 3600) {\n        return pad2(Math.floor(secs / 60 / 60)) + \":\" + pad2(Math.floor(secs / 60) % 60) + \":\" + pad2(Math.floor(secs) % 60);\n    } else if (secs > 60) {\n        return pad2(Math.floor(secs / 60)) + \":\" + pad2(Math.floor(secs) % 60);\n    } else {\n        return Math.floor(secs) + \"s\";\n    }\n}\n\n\nvar originalAppTitle = undefined;\n\nonUiLoaded(function() {\n    originalAppTitle = document.title;\n});\n\nfunction setTitle(progress) {\n    var title = originalAppTitle;\n\n    if (opts.show_progress_in_title && progress) {\n        title = '[' + progress.trim() + '] ' + title;\n    }\n\n    if (document.title != title) {\n        document.title = title;\n    }\n}\n\n\nfunction randomId() {\n    return \"task(\" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + \")\";\n}\n\n// starts sending progress requests to \"/internal/progress\" uri, creating progressbar above 
progressbarContainer element and\n// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.\n// calls onProgress every time there is a progress update\nfunction requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) {\n    var dateStart = new Date();\n    var wasEverActive = false;\n    var parentProgressbar = progressbarContainer.parentNode;\n    var wakeLock = null;\n\n    var requestWakeLock = async function() {\n        if (!opts.prevent_screen_sleep_during_generation || wakeLock) return;\n        try {\n            wakeLock = await navigator.wakeLock.request('screen');\n        } catch (err) {\n            console.error('Wake Lock is not supported.');\n        }\n    };\n\n    var releaseWakeLock = async function() {\n        if (!opts.prevent_screen_sleep_during_generation || !wakeLock) return;\n        try {\n            await wakeLock.release();\n            wakeLock = null;\n        } catch (err) {\n            console.error('Wake Lock release failed', err);\n        }\n    };\n\n    var divProgress = document.createElement('div');\n    divProgress.className = 'progressDiv';\n    divProgress.style.display = opts.show_progressbar ? 
\"block\" : \"none\";\n    var divInner = document.createElement('div');\n    divInner.className = 'progress';\n\n    divProgress.appendChild(divInner);\n    parentProgressbar.insertBefore(divProgress, progressbarContainer);\n\n    var livePreview = null;\n\n    var removeProgressBar = function() {\n        releaseWakeLock();\n        if (!divProgress) return;\n\n        setTitle(\"\");\n        parentProgressbar.removeChild(divProgress);\n        if (gallery && livePreview) gallery.removeChild(livePreview);\n        atEnd();\n\n        divProgress = null;\n    };\n\n    var funProgress = function(id_task) {\n        requestWakeLock();\n        request(\"./internal/progress\", {id_task: id_task, live_preview: false}, function(res) {\n            if (res.completed) {\n                removeProgressBar();\n                return;\n            }\n\n            let progressText = \"\";\n\n            divInner.style.width = ((res.progress || 0) * 100.0) + '%';\n            divInner.style.background = res.progress ? 
\"\" : \"transparent\";\n\n            if (res.progress > 0) {\n                progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%';\n            }\n\n            if (res.eta) {\n                progressText += \" ETA: \" + formatTime(res.eta);\n            }\n\n            setTitle(progressText);\n\n            if (res.textinfo && res.textinfo.indexOf(\"\\n\") == -1) {\n                progressText = res.textinfo + \" \" + progressText;\n            }\n\n            divInner.textContent = progressText;\n\n            var elapsedFromStart = (new Date() - dateStart) / 1000;\n\n            if (res.active) wasEverActive = true;\n\n            if (!res.active && wasEverActive) {\n                removeProgressBar();\n                return;\n            }\n\n            if (elapsedFromStart > inactivityTimeout && !res.queued && !res.active) {\n                removeProgressBar();\n                return;\n            }\n\n            if (onProgress) {\n                onProgress(res);\n            }\n\n            setTimeout(() => {\n                funProgress(id_task, res.id_live_preview);\n            }, opts.live_preview_refresh_period || 500);\n        }, function() {\n            removeProgressBar();\n        });\n    };\n\n    var funLivePreview = function(id_task, id_live_preview) {\n        request(\"./internal/progress\", {id_task: id_task, id_live_preview: id_live_preview}, function(res) {\n            if (!divProgress) {\n                return;\n            }\n\n            if (res.live_preview && gallery) {\n                var img = new Image();\n                img.onload = function() {\n                    if (!livePreview) {\n                        livePreview = document.createElement('div');\n                        livePreview.className = 'livePreview';\n                        gallery.insertBefore(livePreview, gallery.firstElementChild);\n                    }\n\n                    livePreview.appendChild(img);\n                    if 
(livePreview.childElementCount > 2) {\n                        livePreview.removeChild(livePreview.firstElementChild);\n                    }\n                };\n                img.src = res.live_preview;\n            }\n\n            setTimeout(() => {\n                funLivePreview(id_task, res.id_live_preview);\n            }, opts.live_preview_refresh_period || 500);\n        }, function() {\n            removeProgressBar();\n        });\n    };\n\n    funProgress(id_task, 0);\n\n    if (gallery) {\n        funLivePreview(id_task, 0);\n    }\n\n}\n"
  },
  {
    "path": "javascript/resizeHandle.js",
    "content": "(function() {\n    const GRADIO_MIN_WIDTH = 320;\n    const PAD = 16;\n    const DEBOUNCE_TIME = 100;\n    const DOUBLE_TAP_DELAY = 200; //ms\n\n    const R = {\n        tracking: false,\n        parent: null,\n        parentWidth: null,\n        leftCol: null,\n        leftColStartWidth: null,\n        screenX: null,\n        lastTapTime: null,\n    };\n\n    let resizeTimer;\n    let parents = [];\n\n    function setLeftColGridTemplate(el, width) {\n        el.style.gridTemplateColumns = `${width}px 16px 1fr`;\n    }\n\n    function displayResizeHandle(parent) {\n        if (!parent.needHideOnMoblie) {\n            return true;\n        }\n        if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) {\n            parent.style.display = 'flex';\n            parent.resizeHandle.style.display = \"none\";\n            return false;\n        } else {\n            parent.style.display = 'grid';\n            parent.resizeHandle.style.display = \"block\";\n            return true;\n        }\n    }\n\n    function afterResize(parent) {\n        if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != parent.style.originalGridTemplateColumns) {\n            const oldParentWidth = R.parentWidth;\n            const newParentWidth = parent.offsetWidth;\n            const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]);\n\n            const ratio = newParentWidth / oldParentWidth;\n\n            const newWidthL = Math.max(Math.floor(ratio * widthL), parent.minLeftColWidth);\n            setLeftColGridTemplate(parent, newWidthL);\n\n            R.parentWidth = newParentWidth;\n        }\n    }\n\n    function setup(parent) {\n\n        function onDoubleClick(evt) {\n            evt.preventDefault();\n            evt.stopPropagation();\n\n            parent.style.gridTemplateColumns = parent.style.originalGridTemplateColumns;\n        }\n\n        const leftCol = parent.firstElementChild;\n        const rightCol = 
parent.lastElementChild;\n\n        parents.push(parent);\n\n        parent.style.display = 'grid';\n        parent.style.gap = '0';\n        let leftColTemplate = \"\";\n        if (parent.children[0].style.flexGrow) {\n            leftColTemplate = `${parent.children[0].style.flexGrow}fr`;\n            parent.minLeftColWidth = GRADIO_MIN_WIDTH;\n            parent.minRightColWidth = GRADIO_MIN_WIDTH;\n            parent.needHideOnMoblie = true;\n        } else {\n            leftColTemplate = parent.children[0].style.flexBasis;\n            parent.minLeftColWidth = parent.children[0].style.flexBasis.slice(0, -2) / 2;\n            parent.minRightColWidth = 0;\n            parent.needHideOnMoblie = false;\n        }\n\n        if (!leftColTemplate) {\n            leftColTemplate = '1fr';\n        }\n\n        const gridTemplateColumns = `${leftColTemplate} ${PAD}px ${parent.children[1].style.flexGrow}fr`;\n        parent.style.gridTemplateColumns = gridTemplateColumns;\n        parent.style.originalGridTemplateColumns = gridTemplateColumns;\n\n        const resizeHandle = document.createElement('div');\n        resizeHandle.classList.add('resize-handle');\n        parent.insertBefore(resizeHandle, rightCol);\n        parent.resizeHandle = resizeHandle;\n\n        ['mousedown', 'touchstart'].forEach((eventType) => {\n            resizeHandle.addEventListener(eventType, (evt) => {\n                if (eventType.startsWith('mouse')) {\n                    if (evt.button !== 0) return;\n                } else {\n                    if (evt.changedTouches.length !== 1) return;\n\n                    const currentTime = new Date().getTime();\n                    if (R.lastTapTime && currentTime - R.lastTapTime <= DOUBLE_TAP_DELAY) {\n                        onDoubleClick(evt);\n                        return;\n                    }\n\n                    R.lastTapTime = currentTime;\n                }\n\n                evt.preventDefault();\n                
evt.stopPropagation();\n\n                document.body.classList.add('resizing');\n\n                R.tracking = true;\n                R.parent = parent;\n                R.parentWidth = parent.offsetWidth;\n                R.leftCol = leftCol;\n                R.leftColStartWidth = leftCol.offsetWidth;\n                if (eventType.startsWith('mouse')) {\n                    R.screenX = evt.screenX;\n                } else {\n                    R.screenX = evt.changedTouches[0].screenX;\n                }\n            });\n        });\n\n        resizeHandle.addEventListener('dblclick', onDoubleClick);\n\n        afterResize(parent);\n    }\n\n    ['mousemove', 'touchmove'].forEach((eventType) => {\n        window.addEventListener(eventType, (evt) => {\n            if (eventType.startsWith('mouse')) {\n                if (evt.button !== 0) return;\n            } else {\n                if (evt.changedTouches.length !== 1) return;\n            }\n\n            if (R.tracking) {\n                if (eventType.startsWith('mouse')) {\n                    evt.preventDefault();\n                }\n                evt.stopPropagation();\n\n                let delta = 0;\n                if (eventType.startsWith('mouse')) {\n                    delta = R.screenX - evt.screenX;\n                } else {\n                    delta = R.screenX - evt.changedTouches[0].screenX;\n                }\n                const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - R.parent.minRightColWidth - PAD), R.parent.minLeftColWidth);\n                setLeftColGridTemplate(R.parent, leftColWidth);\n            }\n        });\n    });\n\n    ['mouseup', 'touchend'].forEach((eventType) => {\n        window.addEventListener(eventType, (evt) => {\n            if (eventType.startsWith('mouse')) {\n                if (evt.button !== 0) return;\n            } else {\n                if (evt.changedTouches.length !== 1) return;\n            }\n\n      
      if (R.tracking) {\n                evt.preventDefault();\n                evt.stopPropagation();\n\n                R.tracking = false;\n\n                document.body.classList.remove('resizing');\n            }\n        });\n    });\n\n\n    window.addEventListener('resize', () => {\n        clearTimeout(resizeTimer);\n\n        resizeTimer = setTimeout(function() {\n            for (const parent of parents) {\n                afterResize(parent);\n            }\n        }, DEBOUNCE_TIME);\n    });\n\n    setupResizeHandle = setup;\n})();\n\n\nfunction setupAllResizeHandles() {\n    for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) {\n        if (!elem.querySelector('.resize-handle') && !elem.children[0].classList.contains(\"hidden\")) {\n            setupResizeHandle(elem);\n        }\n    }\n}\n\n\nonUiLoaded(setupAllResizeHandles);\n\n"
  },
  {
    "path": "javascript/settings.js",
    "content": "let settingsExcludeTabsFromShowAll = {\n    settings_tab_defaults: 1,\n    settings_tab_sysinfo: 1,\n    settings_tab_actions: 1,\n    settings_tab_licenses: 1,\n};\n\nfunction settingsShowAllTabs() {\n    gradioApp().querySelectorAll('#settings > div').forEach(function(elem) {\n        if (settingsExcludeTabsFromShowAll[elem.id]) return;\n\n        elem.style.display = \"block\";\n    });\n}\n\nfunction settingsShowOneTab() {\n    gradioApp().querySelector('#settings_show_one_page').click();\n}\n\nonUiLoaded(function() {\n    var edit = gradioApp().querySelector('#settings_search');\n    var editTextarea = gradioApp().querySelector('#settings_search > label > input');\n    var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages');\n    var settings_tabs = gradioApp().querySelector('#settings div');\n\n    onEdit('settingsSearch', editTextarea, 250, function() {\n        var searchText = (editTextarea.value || \"\").trim().toLowerCase();\n\n        gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) {\n            var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1;\n            elem.style.display = visible ? 
\"\" : \"none\";\n        });\n\n        if (searchText != \"\") {\n            settingsShowAllTabs();\n        } else {\n            settingsShowOneTab();\n        }\n    });\n\n    settings_tabs.insertBefore(edit, settings_tabs.firstChild);\n    settings_tabs.appendChild(buttonShowAllPages);\n\n\n    buttonShowAllPages.addEventListener(\"click\", settingsShowAllTabs);\n});\n\n\nonOptionsChanged(function() {\n    if (gradioApp().querySelector('#settings .settings-category')) return;\n\n    var sectionMap = {};\n    gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) {\n        sectionMap[x.textContent.trim()] = x;\n    });\n\n    opts._categories.forEach(function(x) {\n        var section = localization[x[0]] ?? x[0];\n        var category = localization[x[1]] ?? x[1];\n\n        var span = document.createElement('SPAN');\n        span.textContent = category;\n        span.className = 'settings-category';\n\n        var sectionElem = sectionMap[section];\n        if (!sectionElem) return;\n\n        sectionElem.parentElement.insertBefore(span, sectionElem);\n    });\n});\n\n"
  },
  {
    "path": "javascript/textualInversion.js",
    "content": "\n\n\nfunction start_training_textual_inversion() {\n    gradioApp().querySelector('#ti_error').innerHTML = '';\n\n    var id = randomId();\n    requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) {\n        gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo;\n    });\n\n    var res = Array.from(arguments);\n\n    res[0] = id;\n\n    return res;\n}\n"
  },
  {
    "path": "javascript/token-counters.js",
    "content": "let promptTokenCountUpdateFunctions = {};\n\nfunction update_txt2img_tokens(...args) {\n    // Called from Gradio\n    update_token_counter(\"txt2img_token_button\");\n    update_token_counter(\"txt2img_negative_token_button\");\n    if (args.length == 2) {\n        return args[0];\n    }\n    return args;\n}\n\nfunction update_img2img_tokens(...args) {\n    // Called from Gradio\n    update_token_counter(\"img2img_token_button\");\n    update_token_counter(\"img2img_negative_token_button\");\n    if (args.length == 2) {\n        return args[0];\n    }\n    return args;\n}\n\nfunction update_token_counter(button_id) {\n    promptTokenCountUpdateFunctions[button_id]?.();\n}\n\n\nfunction recalculatePromptTokens(name) {\n    promptTokenCountUpdateFunctions[name]?.();\n}\n\nfunction recalculate_prompts_txt2img() {\n    // Called from Gradio\n    recalculatePromptTokens('txt2img_prompt');\n    recalculatePromptTokens('txt2img_neg_prompt');\n    return Array.from(arguments);\n}\n\nfunction recalculate_prompts_img2img() {\n    // Called from Gradio\n    recalculatePromptTokens('img2img_prompt');\n    recalculatePromptTokens('img2img_neg_prompt');\n    return Array.from(arguments);\n}\n\nfunction setupTokenCounting(id, id_counter, id_button) {\n    var prompt = gradioApp().getElementById(id);\n    var counter = gradioApp().getElementById(id_counter);\n    var textarea = gradioApp().querySelector(`#${id} > label > textarea`);\n\n    if (counter.parentElement == prompt.parentElement) {\n        return;\n    }\n\n    prompt.parentElement.insertBefore(counter, prompt);\n    prompt.parentElement.style.position = \"relative\";\n\n    var func = onEdit(id, textarea, 800, function() {\n        if (counter.classList.contains(\"token-counter-visible\")) {\n            gradioApp().getElementById(id_button)?.click();\n        }\n    });\n    promptTokenCountUpdateFunctions[id] = func;\n    promptTokenCountUpdateFunctions[id_button] = func;\n}\n\nfunction 
toggleTokenCountingVisibility(id, id_counter, id_button) {\n    var counter = gradioApp().getElementById(id_counter);\n\n    counter.style.display = opts.disable_token_counters ? \"none\" : \"block\";\n    counter.classList.toggle(\"token-counter-visible\", !opts.disable_token_counters);\n}\n\nfunction runCodeForTokenCounters(fun) {\n    fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button');\n    fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button');\n    fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button');\n    fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button');\n}\n\nonUiLoaded(function() {\n    runCodeForTokenCounters(setupTokenCounting);\n});\n\nonOptionsChanged(function() {\n    runCodeForTokenCounters(toggleTokenCountingVisibility);\n});\n"
  },
  {
    "path": "javascript/ui.js",
    "content": "// various functions for interaction with ui.py not large enough to warrant putting them in separate files\n\nfunction set_theme(theme) {\n    var gradioURL = window.location.href;\n    if (!gradioURL.includes('?__theme=')) {\n        window.location.replace(gradioURL + '?__theme=' + theme);\n    }\n}\n\nfunction all_gallery_buttons() {\n    var allGalleryButtons = gradioApp().querySelectorAll('[style=\"display: block;\"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small');\n    var visibleGalleryButtons = [];\n    allGalleryButtons.forEach(function(elem) {\n        if (elem.parentElement.offsetParent) {\n            visibleGalleryButtons.push(elem);\n        }\n    });\n    return visibleGalleryButtons;\n}\n\nfunction selected_gallery_button() {\n    return all_gallery_buttons().find(elem => elem.classList.contains('selected')) ?? null;\n}\n\nfunction selected_gallery_index() {\n    return all_gallery_buttons().findIndex(elem => elem.classList.contains('selected'));\n}\n\nfunction gallery_container_buttons(gallery_container) {\n    return gradioApp().querySelectorAll(`#${gallery_container} .thumbnail-item.thumbnail-small`);\n}\n\nfunction selected_gallery_index_id(gallery_container) {\n    return Array.from(gallery_container_buttons(gallery_container)).findIndex(elem => elem.classList.contains('selected'));\n}\n\nfunction extract_image_from_gallery(gallery) {\n    if (gallery.length == 0) {\n        return [null];\n    }\n    if (gallery.length == 1) {\n        return [gallery[0]];\n    }\n\n    var index = selected_gallery_index();\n\n    if (index < 0 || index >= gallery.length) {\n        // Use the first image in the gallery as the default\n        index = 0;\n    }\n\n    return [gallery[index]];\n}\n\nwindow.args_to_array = Array.from; // Compatibility with e.g. 
extensions that may expect this to be around\n\nfunction switch_to_txt2img() {\n    gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();\n\n    return Array.from(arguments);\n}\n\nfunction switch_to_img2img_tab(no) {\n    gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();\n    gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();\n}\nfunction switch_to_img2img() {\n    switch_to_img2img_tab(0);\n    return Array.from(arguments);\n}\n\nfunction switch_to_sketch() {\n    switch_to_img2img_tab(1);\n    return Array.from(arguments);\n}\n\nfunction switch_to_inpaint() {\n    switch_to_img2img_tab(2);\n    return Array.from(arguments);\n}\n\nfunction switch_to_inpaint_sketch() {\n    switch_to_img2img_tab(3);\n    return Array.from(arguments);\n}\n\nfunction switch_to_extras() {\n    gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();\n\n    return Array.from(arguments);\n}\n\nfunction get_tab_index(tabId) {\n    let buttons = gradioApp().getElementById(tabId).querySelector('div').querySelectorAll('button');\n    for (let i = 0; i < buttons.length; i++) {\n        if (buttons[i].classList.contains('selected')) {\n            return i;\n        }\n    }\n    return 0;\n}\n\nfunction create_tab_index_args(tabId, args) {\n    var res = Array.from(args);\n    res[0] = get_tab_index(tabId);\n    return res;\n}\n\nfunction get_img2img_tab_index() {\n    let res = Array.from(arguments);\n    res.splice(-2);\n    res[0] = get_tab_index('mode_img2img');\n    return res;\n}\n\nfunction create_submit_args(args) {\n    var res = Array.from(args);\n\n    // As it is currently, txt2img and img2img send back the previous output args (txt2img_gallery, generation_info, html_info) whenever you generate a new image.\n    // This can lead to uploading a huge gallery of previously generated images, which leads to an unnecessary delay between submitting and beginning to generate.\n    // I don't 
know why gradio is sending outputs along with inputs, but we can prevent sending the image gallery here, which seems to be an issue for some.\n    // If gradio at some point stops sending outputs, this may break something\n    if (Array.isArray(res[res.length - 3])) {\n        res[res.length - 3] = null;\n    }\n\n    return res;\n}\n\nfunction setSubmitButtonsVisibility(tabname, showInterrupt, showSkip, showInterrupting) {\n    gradioApp().getElementById(tabname + '_interrupt').style.display = showInterrupt ? \"block\" : \"none\";\n    gradioApp().getElementById(tabname + '_skip').style.display = showSkip ? \"block\" : \"none\";\n    gradioApp().getElementById(tabname + '_interrupting').style.display = showInterrupting ? \"block\" : \"none\";\n}\n\nfunction showSubmitButtons(tabname, show) {\n    setSubmitButtonsVisibility(tabname, !show, !show, false);\n}\n\nfunction showSubmitInterruptingPlaceholder(tabname) {\n    setSubmitButtonsVisibility(tabname, false, true, true);\n}\n\nfunction showRestoreProgressButton(tabname, show) {\n    var button = gradioApp().getElementById(tabname + \"_restore_progress\");\n    if (!button) return;\n    button.style.setProperty('display', show ? 
'flex' : 'none', 'important');\n}\n\nfunction submit() {\n    showSubmitButtons('txt2img', false);\n\n    var id = randomId();\n    localSet(\"txt2img_task_id\", id);\n\n    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {\n        showSubmitButtons('txt2img', true);\n        localRemove(\"txt2img_task_id\");\n        showRestoreProgressButton('txt2img', false);\n    });\n\n    var res = create_submit_args(arguments);\n\n    res[0] = id;\n\n    return res;\n}\n\nfunction submit_txt2img_upscale() {\n    var res = submit(...arguments);\n\n    res[2] = selected_gallery_index();\n\n    return res;\n}\n\nfunction submit_img2img() {\n    showSubmitButtons('img2img', false);\n\n    var id = randomId();\n    localSet(\"img2img_task_id\", id);\n\n    requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {\n        showSubmitButtons('img2img', true);\n        localRemove(\"img2img_task_id\");\n        showRestoreProgressButton('img2img', false);\n    });\n\n    var res = create_submit_args(arguments);\n\n    res[0] = id;\n    res[1] = get_tab_index('mode_img2img');\n\n    return res;\n}\n\nfunction submit_extras() {\n    showSubmitButtons('extras', false);\n\n    var id = randomId();\n\n    requestProgress(id, gradioApp().getElementById('extras_gallery_container'), gradioApp().getElementById('extras_gallery'), function() {\n        showSubmitButtons('extras', true);\n    });\n\n    var res = create_submit_args(arguments);\n\n    res[0] = id;\n\n    console.log(res);\n    return res;\n}\n\nfunction restoreProgressTxt2img() {\n    showRestoreProgressButton(\"txt2img\", false);\n    var id = localGet(\"txt2img_task_id\");\n\n    if (id) {\n        showSubmitInterruptingPlaceholder('txt2img');\n        requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), 
gradioApp().getElementById('txt2img_gallery'), function() {\n            showSubmitButtons('txt2img', true);\n        }, null, 0);\n    }\n\n    return id;\n}\n\nfunction restoreProgressImg2img() {\n    showRestoreProgressButton(\"img2img\", false);\n\n    var id = localGet(\"img2img_task_id\");\n\n    if (id) {\n        showSubmitInterruptingPlaceholder('img2img');\n        requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {\n            showSubmitButtons('img2img', true);\n        }, null, 0);\n    }\n\n    return id;\n}\n\n\n/**\n * Configure the width and height elements on `tabname` to accept\n * pasting of resolutions in the form of \"width x height\".\n */\nfunction setupResolutionPasting(tabname) {\n    var width = gradioApp().querySelector(`#${tabname}_width input[type=number]`);\n    var height = gradioApp().querySelector(`#${tabname}_height input[type=number]`);\n    for (const el of [width, height]) {\n        el.addEventListener('paste', function(event) {\n            var pasteData = event.clipboardData.getData('text/plain');\n            var parsed = pasteData.match(/^\\s*(\\d+)\\D+(\\d+)\\s*$/);\n            if (parsed) {\n                width.value = parsed[1];\n                height.value = parsed[2];\n                updateInput(width);\n                updateInput(height);\n                event.preventDefault();\n            }\n        });\n    }\n}\n\nonUiLoaded(function() {\n    showRestoreProgressButton('txt2img', localGet(\"txt2img_task_id\"));\n    showRestoreProgressButton('img2img', localGet(\"img2img_task_id\"));\n    setupResolutionPasting('txt2img');\n    setupResolutionPasting('img2img');\n});\n\n\nfunction modelmerger() {\n    var id = randomId();\n    requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function() {});\n\n    var res = create_submit_args(arguments);\n    res[0] = id;\n    return res;\n}\n\n\nfunction 
ask_for_style_name(_, prompt_text, negative_prompt_text) {\n    var name_ = prompt('Style name:');\n    return [name_, prompt_text, negative_prompt_text];\n}\n\nfunction confirm_clear_prompt(prompt, negative_prompt) {\n    if (confirm(\"Delete prompt?\")) {\n        prompt = \"\";\n        negative_prompt = \"\";\n    }\n\n    return [prompt, negative_prompt];\n}\n\n\nvar opts = {};\nonAfterUiUpdate(function() {\n    if (Object.keys(opts).length != 0) return;\n\n    var json_elem = gradioApp().getElementById('settings_json');\n    if (json_elem == null) return;\n\n    var textarea = json_elem.querySelector('textarea');\n    var jsdata = textarea.value;\n    opts = JSON.parse(jsdata);\n\n    executeCallbacks(optionsAvailableCallbacks); /*global optionsAvailableCallbacks*/\n    executeCallbacks(optionsChangedCallbacks); /*global optionsChangedCallbacks*/\n\n    Object.defineProperty(textarea, 'value', {\n        set: function(newValue) {\n            var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');\n            var oldValue = valueProp.get.call(textarea);\n            valueProp.set.call(textarea, newValue);\n\n            if (oldValue != newValue) {\n                opts = JSON.parse(textarea.value);\n            }\n\n            executeCallbacks(optionsChangedCallbacks);\n        },\n        get: function() {\n            var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');\n            return valueProp.get.call(textarea);\n        }\n    });\n\n    json_elem.parentElement.style.display = \"none\";\n});\n\nonOptionsChanged(function() {\n    var elem = gradioApp().getElementById('sd_checkpoint_hash');\n    var sd_checkpoint_hash = opts.sd_checkpoint_hash || \"\";\n    var shorthash = sd_checkpoint_hash.substring(0, 10);\n\n    if (elem && elem.textContent != shorthash) {\n        elem.textContent = shorthash;\n        elem.title = sd_checkpoint_hash;\n        elem.href = 
\"https://google.com/search?q=\" + sd_checkpoint_hash;\n    }\n});\n\nlet txt2img_textarea, img2img_textarea = undefined;\n\nfunction restart_reload() {\n    document.body.style.backgroundColor = \"var(--background-fill-primary)\";\n    document.body.innerHTML = '<h1 style=\"font-family:monospace;margin-top:20%;color:lightgray;text-align:center;\">Reloading...</h1>';\n    var requestPing = function() {\n        requestGet(\"./internal/ping\", {}, function(data) {\n            location.reload();\n        }, function() {\n            setTimeout(requestPing, 500);\n        });\n    };\n\n    setTimeout(requestPing, 2000);\n\n    return [];\n}\n\n// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits\n// will only visible on web page and not sent to python.\nfunction updateInput(target) {\n    let e = new Event(\"input\", {bubbles: true});\n    Object.defineProperty(e, \"target\", {value: target});\n    target.dispatchEvent(e);\n}\n\n\nvar desiredCheckpointName = null;\nfunction selectCheckpoint(name) {\n    desiredCheckpointName = name;\n    gradioApp().getElementById('change_checkpoint').click();\n}\n\nfunction currentImg2imgSourceResolution(w, h, scaleBy) {\n    var img = gradioApp().querySelector('#mode_img2img > div[style=\"display: block;\"] img');\n    return img ? 
[img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy];\n}\n\nfunction updateImg2imgResizeToTextAfterChangingImage() {\n    // At the time this is called from gradio, the image has no yet been replaced.\n    // There may be a better solution, but this is simple and straightforward so I'm going with it.\n\n    setTimeout(function() {\n        gradioApp().getElementById('img2img_update_resize_to').click();\n    }, 500);\n\n    return [];\n\n}\n\n\n\nfunction setRandomSeed(elem_id) {\n    var input = gradioApp().querySelector(\"#\" + elem_id + \" input\");\n    if (!input) return [];\n\n    input.value = \"-1\";\n    updateInput(input);\n    return [];\n}\n\nfunction switchWidthHeight(tabname) {\n    var width = gradioApp().querySelector(\"#\" + tabname + \"_width input[type=number]\");\n    var height = gradioApp().querySelector(\"#\" + tabname + \"_height input[type=number]\");\n    if (!width || !height) return [];\n\n    var tmp = width.value;\n    width.value = height.value;\n    height.value = tmp;\n\n    updateInput(width);\n    updateInput(height);\n    return [];\n}\n\n\nvar onEditTimers = {};\n\n// calls func after afterMs milliseconds has passed since the input elem has been edited by user\nfunction onEdit(editId, elem, afterMs, func) {\n    var edited = function() {\n        var existingTimer = onEditTimers[editId];\n        if (existingTimer) clearTimeout(existingTimer);\n\n        onEditTimers[editId] = setTimeout(func, afterMs);\n    };\n\n    elem.addEventListener(\"input\", edited);\n\n    return edited;\n}\n"
  },
  {
    "path": "javascript/ui_settings_hints.js",
    "content": "// various hints and extra info for the settings tab\n\nvar settingsHintsSetup = false;\n\nonOptionsChanged(function() {\n    if (settingsHintsSetup) return;\n    settingsHintsSetup = true;\n\n    gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) {\n        var name = div.id.substr(8);\n        var commentBefore = opts._comments_before[name];\n        var commentAfter = opts._comments_after[name];\n\n        if (!commentBefore && !commentAfter) return;\n\n        var span = null;\n        if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span');\n        else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild;\n        else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild;\n        else span = div.querySelector('label span').firstChild;\n\n        if (!span) return;\n\n        if (commentBefore) {\n            var comment = document.createElement('DIV');\n            comment.className = 'settings-comment';\n            comment.innerHTML = commentBefore;\n            span.parentElement.insertBefore(document.createTextNode('\\xa0'), span);\n            span.parentElement.insertBefore(comment, span);\n            span.parentElement.insertBefore(document.createTextNode('\\xa0'), span);\n        }\n        if (commentAfter) {\n            comment = document.createElement('DIV');\n            comment.className = 'settings-comment';\n            comment.innerHTML = commentAfter;\n            span.parentElement.insertBefore(comment, span.nextSibling);\n            span.parentElement.insertBefore(document.createTextNode('\\xa0'), span.nextSibling);\n        }\n    });\n});\n\nfunction settingsHintsShowQuicksettings() {\n    requestGet(\"./internal/quicksettings-hint\", {}, function(data) {\n        var table = document.createElement('table');\n        table.className = 'popup-table';\n\n        
data.forEach(function(obj) {\n            var tr = document.createElement('tr');\n            var td = document.createElement('td');\n            td.textContent = obj.name;\n            tr.appendChild(td);\n\n            td = document.createElement('td');\n            td.textContent = obj.label;\n            tr.appendChild(td);\n\n            table.appendChild(tr);\n        });\n\n        popup(table);\n    });\n}\n"
  },
  {
    "path": "launch.py",
    "content": "from modules import launch_utils\r\n\r\nargs = launch_utils.args\r\npython = launch_utils.python\r\ngit = launch_utils.git\r\nindex_url = launch_utils.index_url\r\ndir_repos = launch_utils.dir_repos\r\n\r\ncommit_hash = launch_utils.commit_hash\r\ngit_tag = launch_utils.git_tag\r\n\r\nrun = launch_utils.run\r\nis_installed = launch_utils.is_installed\r\nrepo_dir = launch_utils.repo_dir\r\n\r\nrun_pip = launch_utils.run_pip\r\ncheck_run_python = launch_utils.check_run_python\r\ngit_clone = launch_utils.git_clone\r\ngit_pull_recursive = launch_utils.git_pull_recursive\r\nlist_extensions = launch_utils.list_extensions\r\nrun_extension_installer = launch_utils.run_extension_installer\r\nprepare_environment = launch_utils.prepare_environment\r\nconfigure_for_tests = launch_utils.configure_for_tests\r\nstart = launch_utils.start\r\n\r\n\r\ndef main():\r\n    if args.dump_sysinfo:\r\n        filename = launch_utils.dump_sysinfo()\r\n\r\n        print(f\"Sysinfo saved as {filename}. Exiting...\")\r\n\r\n        exit(0)\r\n\r\n    launch_utils.startup_timer.record(\"initial startup\")\r\n\r\n    with launch_utils.startup_timer.subcategory(\"prepare environment\"):\r\n        if not args.skip_prepare_environment:\r\n            prepare_environment()\r\n\r\n    if args.test_server:\r\n        configure_for_tests()\r\n\r\n    start()\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    main()\r\n"
  },
  {
    "path": "localizations/Put localization files here.txt",
    "content": ""
  },
  {
    "path": "modules/api/api.py",
    "content": "import base64\nimport io\nimport os\nimport time\nimport datetime\nimport uvicorn\nimport ipaddress\nimport requests\nimport gradio as gr\nfrom threading import Lock\nfrom io import BytesIO\nfrom fastapi import APIRouter, Depends, FastAPI, Request, Response\nfrom fastapi.security import HTTPBasic, HTTPBasicCredentials\nfrom fastapi.exceptions import HTTPException\nfrom fastapi.responses import JSONResponse\nfrom fastapi.encoders import jsonable_encoder\nfrom secrets import compare_digest\n\nimport modules.shared as shared\nfrom modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers\nfrom modules.api import models\nfrom modules.shared import opts\nfrom modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images\nfrom modules.textual_inversion.textual_inversion import create_embedding, train_embedding\nfrom modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork\nfrom PIL import PngImagePlugin\nfrom modules.sd_models_config import find_checkpoint_config_near_filename\nfrom modules.realesrgan_model import get_realesrgan_models\nfrom modules import devices\nfrom typing import Any\nimport piexif\nimport piexif.helper\nfrom contextlib import closing\nfrom modules.progress import create_task_id, add_task_to_queue, start_task, finish_task, current_task\n\ndef script_name_to_index(name, scripts):\n    try:\n        return [script.title().lower() for script in scripts].index(name.lower())\n    except Exception as e:\n        raise HTTPException(status_code=422, detail=f\"Script '{name}' not found\") from e\n\n\ndef validate_sampler_name(name):\n    config = sd_samplers.all_samplers_map.get(name, None)\n    if config is None:\n        raise HTTPException(status_code=400, detail=\"Sampler not found\")\n\n    return name\n\n\ndef setUpscalers(req: dict):\n    reqDict = 
vars(req)\n    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)\n    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)\n    return reqDict\n\n\ndef verify_url(url):\n    \"\"\"Returns True if the url refers to a global resource.\"\"\"\n\n    import socket\n    from urllib.parse import urlparse\n    try:\n        parsed_url = urlparse(url)\n        domain_name = parsed_url.netloc\n        host = socket.gethostbyname_ex(domain_name)\n        for ip in host[2]:\n            ip_addr = ipaddress.ip_address(ip)\n            if not ip_addr.is_global:\n                return False\n    except Exception:\n        return False\n\n    return True\n\n\ndef decode_base64_to_image(encoding):\n    if encoding.startswith(\"http://\") or encoding.startswith(\"https://\"):\n        if not opts.api_enable_requests:\n            raise HTTPException(status_code=500, detail=\"Requests not allowed\")\n\n        if opts.api_forbid_local_requests and not verify_url(encoding):\n            raise HTTPException(status_code=500, detail=\"Request to local resource not allowed\")\n\n        headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {}\n        response = requests.get(encoding, timeout=30, headers=headers)\n        try:\n            image = images.read(BytesIO(response.content))\n            return image\n        except Exception as e:\n            raise HTTPException(status_code=500, detail=\"Invalid image url\") from e\n\n    if encoding.startswith(\"data:image/\"):\n        encoding = encoding.split(\";\")[1].split(\",\")[1]\n    try:\n        image = images.read(BytesIO(base64.b64decode(encoding)))\n        return image\n    except Exception as e:\n        raise HTTPException(status_code=500, detail=\"Invalid encoded image\") from e\n\n\ndef encode_pil_to_base64(image):\n    with io.BytesIO() as output_bytes:\n        if isinstance(image, str):\n            return image\n        if opts.samples_format.lower() == 'png':\n            
use_metadata = False\n            metadata = PngImagePlugin.PngInfo()\n            for key, value in image.info.items():\n                if isinstance(key, str) and isinstance(value, str):\n                    metadata.add_text(key, value)\n                    use_metadata = True\n            image.save(output_bytes, format=\"PNG\", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)\n\n        elif opts.samples_format.lower() in (\"jpg\", \"jpeg\", \"webp\"):\n            if image.mode in (\"RGBA\", \"P\"):\n                image = image.convert(\"RGB\")\n            parameters = image.info.get('parameters', None)\n            exif_bytes = piexif.dump({\n                \"Exif\": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or \"\", encoding=\"unicode\") }\n            })\n            if opts.samples_format.lower() in (\"jpg\", \"jpeg\"):\n                image.save(output_bytes, format=\"JPEG\", exif = exif_bytes, quality=opts.jpeg_quality)\n            else:\n                image.save(output_bytes, format=\"WEBP\", exif = exif_bytes, quality=opts.jpeg_quality)\n\n        else:\n            raise HTTPException(status_code=500, detail=\"Invalid image format\")\n\n        bytes_data = output_bytes.getvalue()\n\n    return base64.b64encode(bytes_data)\n\n\ndef api_middleware(app: FastAPI):\n    rich_available = False\n    try:\n        if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None:\n            import anyio  # importing just so it can be placed on silent list\n            import starlette  # importing just so it can be placed on silent list\n            from rich.console import Console\n            console = Console()\n            rich_available = True\n    except Exception:\n        pass\n\n    @app.middleware(\"http\")\n    async def log_and_time(req: Request, call_next):\n        ts = time.time()\n        res: Response = await call_next(req)\n        duration = str(round(time.time() - ts, 4))\n      
  res.headers[\"X-Process-Time\"] = duration\n        endpoint = req.scope.get('path', 'err')\n        if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'):\n            print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format(\n                t=datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S.%f\"),\n                code=res.status_code,\n                ver=req.scope.get('http_version', '0.0'),\n                cli=req.scope.get('client', ('0:0.0.0', 0))[0],\n                prot=req.scope.get('scheme', 'err'),\n                method=req.scope.get('method', 'err'),\n                endpoint=endpoint,\n                duration=duration,\n            ))\n        return res\n\n    def handle_exception(request: Request, e: Exception):\n        err = {\n            \"error\": type(e).__name__,\n            \"detail\": vars(e).get('detail', ''),\n            \"body\": vars(e).get('body', ''),\n            \"errors\": str(e),\n        }\n        if not isinstance(e, HTTPException):  # do not print backtrace on known httpexceptions\n            message = f\"API error: {request.method}: {request.url} {err}\"\n            if rich_available:\n                print(message)\n                console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))\n            else:\n                errors.report(message, exc_info=True)\n        return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err))\n\n    @app.middleware(\"http\")\n    async def exception_handling(request: Request, call_next):\n        try:\n            return await call_next(request)\n        except Exception as e:\n            return handle_exception(request, e)\n\n    @app.exception_handler(Exception)\n    async def fastapi_exception_handler(request: Request, e: Exception):\n        return handle_exception(request, e)\n\n    
@app.exception_handler(HTTPException)\n    async def http_exception_handler(request: Request, e: HTTPException):\n        return handle_exception(request, e)\n\n\nclass Api:\n    def __init__(self, app: FastAPI, queue_lock: Lock):\n        if shared.cmd_opts.api_auth:\n            self.credentials = {}\n            for auth in shared.cmd_opts.api_auth.split(\",\"):\n                user, password = auth.split(\":\")\n                self.credentials[user] = password\n\n        self.router = APIRouter()\n        self.app = app\n        self.queue_lock = queue_lock\n        api_middleware(self.app)\n        self.add_api_route(\"/sdapi/v1/txt2img\", self.text2imgapi, methods=[\"POST\"], response_model=models.TextToImageResponse)\n        self.add_api_route(\"/sdapi/v1/img2img\", self.img2imgapi, methods=[\"POST\"], response_model=models.ImageToImageResponse)\n        self.add_api_route(\"/sdapi/v1/extra-single-image\", self.extras_single_image_api, methods=[\"POST\"], response_model=models.ExtrasSingleImageResponse)\n        self.add_api_route(\"/sdapi/v1/extra-batch-images\", self.extras_batch_images_api, methods=[\"POST\"], response_model=models.ExtrasBatchImagesResponse)\n        self.add_api_route(\"/sdapi/v1/png-info\", self.pnginfoapi, methods=[\"POST\"], response_model=models.PNGInfoResponse)\n        self.add_api_route(\"/sdapi/v1/progress\", self.progressapi, methods=[\"GET\"], response_model=models.ProgressResponse)\n        self.add_api_route(\"/sdapi/v1/interrogate\", self.interrogateapi, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/interrupt\", self.interruptapi, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/skip\", self.skip, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/options\", self.get_config, methods=[\"GET\"], response_model=models.OptionsModel)\n        self.add_api_route(\"/sdapi/v1/options\", self.set_config, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/cmd-flags\", self.get_cmd_flags, 
methods=[\"GET\"], response_model=models.FlagsModel)\n        self.add_api_route(\"/sdapi/v1/samplers\", self.get_samplers, methods=[\"GET\"], response_model=list[models.SamplerItem])\n        self.add_api_route(\"/sdapi/v1/schedulers\", self.get_schedulers, methods=[\"GET\"], response_model=list[models.SchedulerItem])\n        self.add_api_route(\"/sdapi/v1/upscalers\", self.get_upscalers, methods=[\"GET\"], response_model=list[models.UpscalerItem])\n        self.add_api_route(\"/sdapi/v1/latent-upscale-modes\", self.get_latent_upscale_modes, methods=[\"GET\"], response_model=list[models.LatentUpscalerModeItem])\n        self.add_api_route(\"/sdapi/v1/sd-models\", self.get_sd_models, methods=[\"GET\"], response_model=list[models.SDModelItem])\n        self.add_api_route(\"/sdapi/v1/sd-vae\", self.get_sd_vaes, methods=[\"GET\"], response_model=list[models.SDVaeItem])\n        self.add_api_route(\"/sdapi/v1/hypernetworks\", self.get_hypernetworks, methods=[\"GET\"], response_model=list[models.HypernetworkItem])\n        self.add_api_route(\"/sdapi/v1/face-restorers\", self.get_face_restorers, methods=[\"GET\"], response_model=list[models.FaceRestorerItem])\n        self.add_api_route(\"/sdapi/v1/realesrgan-models\", self.get_realesrgan_models, methods=[\"GET\"], response_model=list[models.RealesrganItem])\n        self.add_api_route(\"/sdapi/v1/prompt-styles\", self.get_prompt_styles, methods=[\"GET\"], response_model=list[models.PromptStyleItem])\n        self.add_api_route(\"/sdapi/v1/embeddings\", self.get_embeddings, methods=[\"GET\"], response_model=models.EmbeddingsResponse)\n        self.add_api_route(\"/sdapi/v1/refresh-embeddings\", self.refresh_embeddings, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/refresh-checkpoints\", self.refresh_checkpoints, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/refresh-vae\", self.refresh_vae, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/create/embedding\", self.create_embedding, 
methods=[\"POST\"], response_model=models.CreateResponse)\n        self.add_api_route(\"/sdapi/v1/create/hypernetwork\", self.create_hypernetwork, methods=[\"POST\"], response_model=models.CreateResponse)\n        self.add_api_route(\"/sdapi/v1/train/embedding\", self.train_embedding, methods=[\"POST\"], response_model=models.TrainResponse)\n        self.add_api_route(\"/sdapi/v1/train/hypernetwork\", self.train_hypernetwork, methods=[\"POST\"], response_model=models.TrainResponse)\n        self.add_api_route(\"/sdapi/v1/memory\", self.get_memory, methods=[\"GET\"], response_model=models.MemoryResponse)\n        self.add_api_route(\"/sdapi/v1/unload-checkpoint\", self.unloadapi, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/reload-checkpoint\", self.reloadapi, methods=[\"POST\"])\n        self.add_api_route(\"/sdapi/v1/scripts\", self.get_scripts_list, methods=[\"GET\"], response_model=models.ScriptsList)\n        self.add_api_route(\"/sdapi/v1/script-info\", self.get_script_info, methods=[\"GET\"], response_model=list[models.ScriptInfo])\n        self.add_api_route(\"/sdapi/v1/extensions\", self.get_extensions_list, methods=[\"GET\"], response_model=list[models.ExtensionItem])\n\n        if shared.cmd_opts.api_server_stop:\n            self.add_api_route(\"/sdapi/v1/server-kill\", self.kill_webui, methods=[\"POST\"])\n            self.add_api_route(\"/sdapi/v1/server-restart\", self.restart_webui, methods=[\"POST\"])\n            self.add_api_route(\"/sdapi/v1/server-stop\", self.stop_webui, methods=[\"POST\"])\n\n        self.default_script_arg_txt2img = []\n        self.default_script_arg_img2img = []\n\n        txt2img_script_runner = scripts.scripts_txt2img\n        img2img_script_runner = scripts.scripts_img2img\n\n        if not txt2img_script_runner.scripts or not img2img_script_runner.scripts:\n            ui.create_ui()\n\n        if not txt2img_script_runner.scripts:\n            txt2img_script_runner.initialize_scripts(False)\n        if 
not self.default_script_arg_txt2img:\n            self.default_script_arg_txt2img = self.init_default_script_args(txt2img_script_runner)\n\n        if not img2img_script_runner.scripts:\n            img2img_script_runner.initialize_scripts(True)\n        if not self.default_script_arg_img2img:\n            self.default_script_arg_img2img = self.init_default_script_args(img2img_script_runner)\n\n\n\n    def add_api_route(self, path: str, endpoint, **kwargs):\n        if shared.cmd_opts.api_auth:\n            return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs)\n        return self.app.add_api_route(path, endpoint, **kwargs)\n\n    def auth(self, credentials: HTTPBasicCredentials = Depends(HTTPBasic())):\n        if credentials.username in self.credentials:\n            if compare_digest(credentials.password, self.credentials[credentials.username]):\n                return True\n\n        raise HTTPException(status_code=401, detail=\"Incorrect username or password\", headers={\"WWW-Authenticate\": \"Basic\"})\n\n    def get_selectable_script(self, script_name, script_runner):\n        if script_name is None or script_name == \"\":\n            return None, None\n\n        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)\n        script = script_runner.selectable_scripts[script_idx]\n        return script, script_idx\n\n    def get_scripts_list(self):\n        t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]\n        i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]\n\n        return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)\n\n    def get_script_info(self):\n        res = []\n\n        for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:\n            res += [script.api_info for script in script_list if script.api_info is not None]\n\n        return res\n\n    
def get_script(self, script_name, script_runner):\n        if script_name is None or script_name == \"\":\n            return None, None\n\n        script_idx = script_name_to_index(script_name, script_runner.scripts)\n        return script_runner.scripts[script_idx]\n\n    def init_default_script_args(self, script_runner):\n        #find max idx from the scripts in runner and generate a none array to init script_args\n        last_arg_index = 1\n        for script in script_runner.scripts:\n            if last_arg_index < script.args_to:\n                last_arg_index = script.args_to\n        # None everywhere except position 0 to initialize script args\n        script_args = [None]*last_arg_index\n        script_args[0] = 0\n\n        # get default values\n        with gr.Blocks(): # will throw errors calling ui function without this\n            for script in script_runner.scripts:\n                if script.ui(script.is_img2img):\n                    ui_default_values = []\n                    for elem in script.ui(script.is_img2img):\n                        ui_default_values.append(elem.value)\n                    script_args[script.args_from:script.args_to] = ui_default_values\n        return script_args\n\n    def init_script_args(self, request, default_script_args, selectable_scripts, selectable_idx, script_runner, *, input_script_args=None):\n        script_args = default_script_args.copy()\n\n        if input_script_args is not None:\n            for index, value in input_script_args.items():\n                script_args[index] = value\n\n        # position 0 in script_arg is the idx+1 of the selectable script that is going to be run when using scripts.scripts_*2img.run()\n        if selectable_scripts:\n            script_args[selectable_scripts.args_from:selectable_scripts.args_to] = request.script_args\n            script_args[0] = selectable_idx + 1\n\n        # Now check for always on scripts\n        if request.alwayson_scripts:\n            for 
alwayson_script_name in request.alwayson_scripts.keys():\n                alwayson_script = self.get_script(alwayson_script_name, script_runner)\n                if alwayson_script is None:\n                    raise HTTPException(status_code=422, detail=f\"always on script {alwayson_script_name} not found\")\n                # Selectable script in always on script param check\n                if alwayson_script.alwayson is False:\n                    raise HTTPException(status_code=422, detail=\"Cannot have a selectable script in the always on scripts params\")\n                # always on script with no arg should always run so you don't really need to add them to the requests\n                if \"args\" in request.alwayson_scripts[alwayson_script_name]:\n                    # min between arg length in scriptrunner and arg length in the request\n                    for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name][\"args\"]))):\n                        script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name][\"args\"][idx]\n        return script_args\n\n    def apply_infotext(self, request, tabname, *, script_runner=None, mentioned_script_args=None):\n        \"\"\"Processes `infotext` field from the `request`, and sets other fields of the `request` according to what's in infotext.\n\n        If request already has a field set, and that field is encountered in infotext too, the value from infotext is ignored.\n\n        Additionally, fills `mentioned_script_args` dict with index: value pairs for script arguments read from infotext.\n        \"\"\"\n\n        if not request.infotext:\n            return {}\n\n        possible_fields = infotext_utils.paste_fields[tabname][\"fields\"]\n        set_fields = request.model_dump(exclude_unset=True) if hasattr(request, \"request\") else request.dict(exclude_unset=True)  # pydantic v1/v2 have different 
names for this\n        params = infotext_utils.parse_generation_parameters(request.infotext)\n\n        def get_field_value(field, params):\n            value = field.function(params) if field.function else params.get(field.label)\n            if value is None:\n                return None\n\n            if field.api in request.__fields__:\n                target_type = request.__fields__[field.api].type_\n            else:\n                target_type = type(field.component.value)\n\n            if target_type == type(None):\n                return None\n\n            if isinstance(value, dict) and value.get('__type__') == 'generic_update':  # this is a gradio.update rather than a value\n                value = value.get('value')\n\n            if value is not None and not isinstance(value, target_type):\n                value = target_type(value)\n\n            return value\n\n        for field in possible_fields:\n            if not field.api:\n                continue\n\n            if field.api in set_fields:\n                continue\n\n            value = get_field_value(field, params)\n            if value is not None:\n                setattr(request, field.api, value)\n\n        if request.override_settings is None:\n            request.override_settings = {}\n\n        overridden_settings = infotext_utils.get_override_settings(params)\n        for _, setting_name, value in overridden_settings:\n            if setting_name not in request.override_settings:\n                request.override_settings[setting_name] = value\n\n        if script_runner is not None and mentioned_script_args is not None:\n            indexes = {v: i for i, v in enumerate(script_runner.inputs)}\n            script_fields = ((field, indexes[field.component]) for field in possible_fields if field.component in indexes)\n\n            for field, index in script_fields:\n                value = get_field_value(field, params)\n\n                if value is None:\n                    
continue\n\n                mentioned_script_args[index] = value\n\n        return params\n\n    def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):\n        task_id = txt2imgreq.force_task_id or create_task_id(\"txt2img\")\n\n        script_runner = scripts.scripts_txt2img\n\n        infotext_script_args = {}\n        self.apply_infotext(txt2imgreq, \"txt2img\", script_runner=script_runner, mentioned_script_args=infotext_script_args)\n\n        selectable_scripts, selectable_script_idx = self.get_selectable_script(txt2imgreq.script_name, script_runner)\n        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(txt2imgreq.sampler_name or txt2imgreq.sampler_index, txt2imgreq.scheduler)\n\n        populate = txt2imgreq.copy(update={  # Override __init__ params\n            \"sampler_name\": validate_sampler_name(sampler),\n            \"do_not_save_samples\": not txt2imgreq.save_images,\n            \"do_not_save_grid\": not txt2imgreq.save_images,\n        })\n        if populate.sampler_name:\n            populate.sampler_index = None  # prevent a warning later on\n\n        if not populate.scheduler and scheduler != \"Automatic\":\n            populate.scheduler = scheduler\n\n        args = vars(populate)\n        args.pop('script_name', None)\n        args.pop('script_args', None) # will refeed them to the pipeline directly after initializing them\n        args.pop('alwayson_scripts', None)\n        args.pop('infotext', None)\n\n        script_args = self.init_script_args(txt2imgreq, self.default_script_arg_txt2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)\n\n        send_images = args.pop('send_images', True)\n        args.pop('save_images', None)\n\n        add_task_to_queue(task_id)\n\n        with self.queue_lock:\n            with closing(StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)) as p:\n                p.is_api = True\n                
p.scripts = script_runner\n                p.outpath_grids = opts.outdir_txt2img_grids\n                p.outpath_samples = opts.outdir_txt2img_samples\n\n                try:\n                    shared.state.begin(job=\"scripts_txt2img\")\n                    start_task(task_id)\n                    if selectable_scripts is not None:\n                        p.script_args = script_args\n                        processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here\n                    else:\n                        p.script_args = tuple(script_args) # Need to pass args as tuple here\n                        processed = process_images(p)\n                    finish_task(task_id)\n                finally:\n                    shared.state.end()\n                    shared.total_tqdm.clear()\n\n        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []\n\n        return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())\n\n    def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):\n        task_id = img2imgreq.force_task_id or create_task_id(\"img2img\")\n\n        init_images = img2imgreq.init_images\n        if init_images is None:\n            raise HTTPException(status_code=404, detail=\"Init image not found\")\n\n        mask = img2imgreq.mask\n        if mask:\n            mask = decode_base64_to_image(mask)\n\n        script_runner = scripts.scripts_img2img\n\n        infotext_script_args = {}\n        self.apply_infotext(img2imgreq, \"img2img\", script_runner=script_runner, mentioned_script_args=infotext_script_args)\n\n        selectable_scripts, selectable_script_idx = self.get_selectable_script(img2imgreq.script_name, script_runner)\n        sampler, scheduler = sd_samplers.get_sampler_and_scheduler(img2imgreq.sampler_name or img2imgreq.sampler_index, img2imgreq.scheduler)\n\n        populate = img2imgreq.copy(update={  
# Override __init__ params\n            \"sampler_name\": validate_sampler_name(sampler),\n            \"do_not_save_samples\": not img2imgreq.save_images,\n            \"do_not_save_grid\": not img2imgreq.save_images,\n            \"mask\": mask,\n        })\n        if populate.sampler_name:\n            populate.sampler_index = None  # prevent a warning later on\n\n        if not populate.scheduler and scheduler != \"Automatic\":\n            populate.scheduler = scheduler\n\n        args = vars(populate)\n        args.pop('include_init_images', None)  # this is meant to be done by \"exclude\": True in model, but it's for a reason that I cannot determine.\n        args.pop('script_name', None)\n        args.pop('script_args', None)  # will refeed them to the pipeline directly after initializing them\n        args.pop('alwayson_scripts', None)\n        args.pop('infotext', None)\n\n        script_args = self.init_script_args(img2imgreq, self.default_script_arg_img2img, selectable_scripts, selectable_script_idx, script_runner, input_script_args=infotext_script_args)\n\n        send_images = args.pop('send_images', True)\n        args.pop('save_images', None)\n\n        add_task_to_queue(task_id)\n\n        with self.queue_lock:\n            with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p:\n                p.init_images = [decode_base64_to_image(x) for x in init_images]\n                p.is_api = True\n                p.scripts = script_runner\n                p.outpath_grids = opts.outdir_img2img_grids\n                p.outpath_samples = opts.outdir_img2img_samples\n\n                try:\n                    shared.state.begin(job=\"scripts_img2img\")\n                    start_task(task_id)\n                    if selectable_scripts is not None:\n                        p.script_args = script_args\n                        processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here\n          
          else:\n                        p.script_args = tuple(script_args) # Need to pass args as tuple here\n                        processed = process_images(p)\n                    finish_task(task_id)\n                finally:\n                    shared.state.end()\n                    shared.total_tqdm.clear()\n\n        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []\n\n        if not img2imgreq.include_init_images:\n            img2imgreq.init_images = None\n            img2imgreq.mask = None\n\n        return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())\n\n    def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):\n        reqDict = setUpscalers(req)\n\n        reqDict['image'] = decode_base64_to_image(reqDict['image'])\n\n        with self.queue_lock:\n            result = postprocessing.run_extras(extras_mode=0, image_folder=\"\", input_dir=\"\", output_dir=\"\", save_output=False, **reqDict)\n\n        return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])\n\n    def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):\n        reqDict = setUpscalers(req)\n\n        image_list = reqDict.pop('imageList', [])\n        image_folder = [decode_base64_to_image(x.data) for x in image_list]\n\n        with self.queue_lock:\n            result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image=\"\", input_dir=\"\", output_dir=\"\", save_output=False, **reqDict)\n\n        return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])\n\n    def pnginfoapi(self, req: models.PNGInfoRequest):\n        image = decode_base64_to_image(req.image.strip())\n        if image is None:\n            return models.PNGInfoResponse(info=\"\")\n\n        geninfo, items = images.read_info_from_image(image)\n        if geninfo is None:\n   
         geninfo = \"\"\n\n        params = infotext_utils.parse_generation_parameters(geninfo)\n        script_callbacks.infotext_pasted_callback(geninfo, params)\n\n        return models.PNGInfoResponse(info=geninfo, items=items, parameters=params)\n\n    def progressapi(self, req: models.ProgressRequest = Depends()):\n        # copy from check_progress_call of ui.py\n\n        if shared.state.job_count == 0:\n            return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)\n\n        # avoid dividing zero\n        progress = 0.01\n\n        if shared.state.job_count > 0:\n            progress += shared.state.job_no / shared.state.job_count\n        if shared.state.sampling_steps > 0:\n            progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps\n\n        time_since_start = time.time() - shared.state.time_start\n        eta = (time_since_start/progress)\n        eta_relative = eta-time_since_start\n\n        progress = min(progress, 1)\n\n        shared.state.set_current_image()\n\n        current_image = None\n        if shared.state.current_image and not req.skip_current_image:\n            current_image = encode_pil_to_base64(shared.state.current_image)\n\n        return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo, current_task=current_task)\n\n    def interrogateapi(self, interrogatereq: models.InterrogateRequest):\n        image_b64 = interrogatereq.image\n        if image_b64 is None:\n            raise HTTPException(status_code=404, detail=\"Image not found\")\n\n        img = decode_base64_to_image(image_b64)\n        img = img.convert('RGB')\n\n        # Override object param\n        with self.queue_lock:\n            if interrogatereq.model == \"clip\":\n                processed = shared.interrogator.interrogate(img)\n          
  elif interrogatereq.model == \"deepdanbooru\":\n                processed = deepbooru.model.tag(img)\n            else:\n                raise HTTPException(status_code=404, detail=\"Model not found\")\n\n        return models.InterrogateResponse(caption=processed)\n\n    def interruptapi(self):\n        shared.state.interrupt()\n\n        return {}\n\n    def unloadapi(self):\n        sd_models.unload_model_weights()\n\n        return {}\n\n    def reloadapi(self):\n        sd_models.send_model_to_device(shared.sd_model)\n\n        return {}\n\n    def skip(self):\n        shared.state.skip()\n\n    def get_config(self):\n        options = {}\n        for key in shared.opts.data.keys():\n            metadata = shared.opts.data_labels.get(key)\n            if(metadata is not None):\n                options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})\n            else:\n                options.update({key: shared.opts.data.get(key, None)})\n\n        return options\n\n    def set_config(self, req: dict[str, Any]):\n        checkpoint_name = req.get(\"sd_model_checkpoint\", None)\n        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:\n            raise RuntimeError(f\"model {checkpoint_name!r} not found\")\n\n        for k, v in req.items():\n            shared.opts.set(k, v, is_api=True)\n\n        shared.opts.save(shared.config_filename)\n        return\n\n    def get_cmd_flags(self):\n        return vars(shared.cmd_opts)\n\n    def get_samplers(self):\n        return [{\"name\": sampler[0], \"aliases\":sampler[2], \"options\":sampler[3]} for sampler in sd_samplers.all_samplers]\n\n    def get_schedulers(self):\n        return [\n            {\n                \"name\": scheduler.name,\n                \"label\": scheduler.label,\n                \"aliases\": scheduler.aliases,\n                \"default_rho\": scheduler.default_rho,\n                \"need_inner_model\": 
scheduler.need_inner_model,\n            }\n            for scheduler in sd_schedulers.schedulers]\n\n    def get_upscalers(self):\n        return [\n            {\n                \"name\": upscaler.name,\n                \"model_name\": upscaler.scaler.model_name,\n                \"model_path\": upscaler.data_path,\n                \"model_url\": None,\n                \"scale\": upscaler.scale,\n            }\n            for upscaler in shared.sd_upscalers\n        ]\n\n    def get_latent_upscale_modes(self):\n        return [\n            {\n                \"name\": upscale_mode,\n            }\n            for upscale_mode in [*(shared.latent_upscale_modes or {})]\n        ]\n\n    def get_sd_models(self):\n        import modules.sd_models as sd_models\n        return [{\"title\": x.title, \"model_name\": x.model_name, \"hash\": x.shorthash, \"sha256\": x.sha256, \"filename\": x.filename, \"config\": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()]\n\n    def get_sd_vaes(self):\n        import modules.sd_vae as sd_vae\n        return [{\"model_name\": x, \"filename\": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()]\n\n    def get_hypernetworks(self):\n        return [{\"name\": name, \"path\": shared.hypernetworks[name]} for name in shared.hypernetworks]\n\n    def get_face_restorers(self):\n        return [{\"name\":x.name(), \"cmd_dir\": getattr(x, \"cmd_dir\", None)} for x in shared.face_restorers]\n\n    def get_realesrgan_models(self):\n        return [{\"name\":x.name,\"path\":x.data_path, \"scale\":x.scale} for x in get_realesrgan_models(None)]\n\n    def get_prompt_styles(self):\n        styleList = []\n        for k in shared.prompt_styles.styles:\n            style = shared.prompt_styles.styles[k]\n            styleList.append({\"name\":style[0], \"prompt\": style[1], \"negative_prompt\": style[2]})\n\n        return styleList\n\n    def get_embeddings(self):\n        db = 
sd_hijack.model_hijack.embedding_db\n\n        def convert_embedding(embedding):\n            return {\n                \"step\": embedding.step,\n                \"sd_checkpoint\": embedding.sd_checkpoint,\n                \"sd_checkpoint_name\": embedding.sd_checkpoint_name,\n                \"shape\": embedding.shape,\n                \"vectors\": embedding.vectors,\n            }\n\n        def convert_embeddings(embeddings):\n            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}\n\n        return {\n            \"loaded\": convert_embeddings(db.word_embeddings),\n            \"skipped\": convert_embeddings(db.skipped_embeddings),\n        }\n\n    def refresh_embeddings(self):\n        with self.queue_lock:\n            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)\n\n    def refresh_checkpoints(self):\n        with self.queue_lock:\n            shared.refresh_checkpoints()\n\n    def refresh_vae(self):\n        with self.queue_lock:\n            shared_items.refresh_vae_list()\n\n    def create_embedding(self, args: dict):\n        try:\n            shared.state.begin(job=\"create_embedding\")\n            filename = create_embedding(**args) # create empty embedding\n            sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used\n            return models.CreateResponse(info=f\"create embedding filename: {filename}\")\n        except AssertionError as e:\n            return models.TrainResponse(info=f\"create embedding error: {e}\")\n        finally:\n            shared.state.end()\n\n\n    def create_hypernetwork(self, args: dict):\n        try:\n            shared.state.begin(job=\"create_hypernetwork\")\n            filename = create_hypernetwork(**args) # create empty embedding\n            return models.CreateResponse(info=f\"create hypernetwork filename: {filename}\")\n        except 
AssertionError as e:\n            return models.TrainResponse(info=f\"create hypernetwork error: {e}\")\n        finally:\n            shared.state.end()\n\n    def train_embedding(self, args: dict):\n        try:\n            shared.state.begin(job=\"train_embedding\")\n            apply_optimizations = shared.opts.training_xattention_optimizations\n            error = None\n            filename = ''\n            if not apply_optimizations:\n                sd_hijack.undo_optimizations()\n            try:\n                embedding, filename = train_embedding(**args) # can take a long time to complete\n            except Exception as e:\n                error = e\n            finally:\n                if not apply_optimizations:\n                    sd_hijack.apply_optimizations()\n            return models.TrainResponse(info=f\"train embedding complete: filename: {filename} error: {error}\")\n        except Exception as msg:\n            return models.TrainResponse(info=f\"train embedding error: {msg}\")\n        finally:\n            shared.state.end()\n\n    def train_hypernetwork(self, args: dict):\n        try:\n            shared.state.begin(job=\"train_hypernetwork\")\n            shared.loaded_hypernetworks = []\n            apply_optimizations = shared.opts.training_xattention_optimizations\n            error = None\n            filename = ''\n            if not apply_optimizations:\n                sd_hijack.undo_optimizations()\n            try:\n                hypernetwork, filename = train_hypernetwork(**args)\n            except Exception as e:\n                error = e\n            finally:\n                shared.sd_model.cond_stage_model.to(devices.device)\n                shared.sd_model.first_stage_model.to(devices.device)\n                if not apply_optimizations:\n                    sd_hijack.apply_optimizations()\n                shared.state.end()\n            return models.TrainResponse(info=f\"train embedding complete: filename: 
{filename} error: {error}\")\n        except Exception as exc:\n            return models.TrainResponse(info=f\"train embedding error: {exc}\")\n        finally:\n            shared.state.end()\n\n    def get_memory(self):\n        try:\n            import os\n            import psutil\n            process = psutil.Process(os.getpid())\n            res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values\n            ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe\n            ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total }\n        except Exception as err:\n            ram = { 'error': f'{err}' }\n        try:\n            import torch\n            if torch.cuda.is_available():\n                s = torch.cuda.mem_get_info()\n                system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] }\n                s = dict(torch.cuda.memory_stats(shared.device))\n                allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] }\n                reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] }\n                active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] }\n                inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] }\n                warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] }\n                cuda = {\n                    'system': system,\n                    'active': active,\n                    'allocated': allocated,\n                    'reserved': reserved,\n                    'inactive': inactive,\n                    'events': warnings,\n                }\n            else:\n                cuda = {'error': 'unavailable'}\n        except Exception as err:\n            cuda = 
{'error': f'{err}'}\n        return models.MemoryResponse(ram=ram, cuda=cuda)\n\n    def get_extensions_list(self):\n        from modules import extensions\n        extensions.list_extensions()\n        ext_list = []\n        for ext in extensions.extensions:\n            ext: extensions.Extension\n            ext.read_info_from_repo()\n            if ext.remote is not None:\n                ext_list.append({\n                    \"name\": ext.name,\n                    \"remote\": ext.remote,\n                    \"branch\": ext.branch,\n                    \"commit_hash\":ext.commit_hash,\n                    \"commit_date\":ext.commit_date,\n                    \"version\":ext.version,\n                    \"enabled\":ext.enabled\n                })\n        return ext_list\n\n    def launch(self, server_name, port, root_path):\n        self.app.include_router(self.router)\n        uvicorn.run(\n            self.app,\n            host=server_name,\n            port=port,\n            timeout_keep_alive=shared.cmd_opts.timeout_keep_alive,\n            root_path=root_path,\n            ssl_keyfile=shared.cmd_opts.tls_keyfile,\n            ssl_certfile=shared.cmd_opts.tls_certfile\n        )\n\n    def kill_webui(self):\n        restart.stop_program()\n\n    def restart_webui(self):\n        if restart.is_restartable():\n            restart.restart_program()\n        return Response(status_code=501)\n\n    def stop_webui(request):\n        shared.state.server_command = \"stop\"\n        return Response(\"Stopping.\")\n\n"
  },
  {
    "path": "modules/api/models.py",
    "content": "import inspect\n\nfrom pydantic import BaseModel, Field, create_model\nfrom typing import Any, Optional, Literal\nfrom inflection import underscore\nfrom modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img\nfrom modules.shared import sd_upscalers, opts, parser\n\nAPI_NOT_ALLOWED = [\n    \"self\",\n    \"kwargs\",\n    \"sd_model\",\n    \"outpath_samples\",\n    \"outpath_grids\",\n    \"sampler_index\",\n    # \"do_not_save_samples\",\n    # \"do_not_save_grid\",\n    \"extra_generation_params\",\n    \"overlay_images\",\n    \"do_not_reload_embeddings\",\n    \"seed_enable_extras\",\n    \"prompt_for_display\",\n    \"sampler_noise_scheduler_override\",\n    \"ddim_discretize\"\n]\n\nclass ModelDef(BaseModel):\n    \"\"\"Assistance Class for Pydantic Dynamic Model Generation\"\"\"\n\n    field: str\n    field_alias: str\n    field_type: Any\n    field_value: Any\n    field_exclude: bool = False\n\n\nclass PydanticModelGenerator:\n    \"\"\"\n    Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:\n    source_data is a snapshot of the default values produced by the class\n    params are the names of the actual keys required by __init__\n    \"\"\"\n\n    def __init__(\n        self,\n        model_name: str = None,\n        class_instance = None,\n        additional_fields = None,\n    ):\n        def field_type_generator(k, v):\n            field_type = v.annotation\n\n            if field_type == 'Image':\n                # images are sent as base64 strings via API\n                field_type = 'str'\n\n            return Optional[field_type]\n\n        def merge_class_params(class_):\n            all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))\n            parameters = {}\n            for classes in all_classes:\n                parameters = {**parameters, **inspect.signature(classes.__init__).parameters}\n            return parameters\n\n   
     self._model_name = model_name\n        self._class_data = merge_class_params(class_instance)\n\n        self._model_def = [\n            ModelDef(\n                field=underscore(k),\n                field_alias=k,\n                field_type=field_type_generator(k, v),\n                field_value=None if isinstance(v.default, property) else v.default\n            )\n            for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED\n        ]\n\n        for fields in additional_fields:\n            self._model_def.append(ModelDef(\n                field=underscore(fields[\"key\"]),\n                field_alias=fields[\"key\"],\n                field_type=fields[\"type\"],\n                field_value=fields[\"default\"],\n                field_exclude=fields[\"exclude\"] if \"exclude\" in fields else False))\n\n    def generate_model(self):\n        \"\"\"\n        Creates a pydantic BaseModel\n        from the json and overrides provided at initialization\n        \"\"\"\n        fields = {\n            d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias, exclude=d.field_exclude)) for d in self._model_def\n        }\n        DynamicModel = create_model(self._model_name, **fields)\n        DynamicModel.__config__.allow_population_by_field_name = True\n        DynamicModel.__config__.allow_mutation = True\n        return DynamicModel\n\nStableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(\n    \"StableDiffusionProcessingTxt2Img\",\n    StableDiffusionProcessingTxt2Img,\n    [\n        {\"key\": \"sampler_index\", \"type\": str, \"default\": \"Euler\"},\n        {\"key\": \"script_name\", \"type\": str, \"default\": None},\n        {\"key\": \"script_args\", \"type\": list, \"default\": []},\n        {\"key\": \"send_images\", \"type\": bool, \"default\": True},\n        {\"key\": \"save_images\", \"type\": bool, \"default\": False},\n        {\"key\": \"alwayson_scripts\", \"type\": dict, \"default\": {}},\n        
{\"key\": \"force_task_id\", \"type\": str, \"default\": None},\n        {\"key\": \"infotext\", \"type\": str, \"default\": None},\n    ]\n).generate_model()\n\nStableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(\n    \"StableDiffusionProcessingImg2Img\",\n    StableDiffusionProcessingImg2Img,\n    [\n        {\"key\": \"sampler_index\", \"type\": str, \"default\": \"Euler\"},\n        {\"key\": \"init_images\", \"type\": list, \"default\": None},\n        {\"key\": \"denoising_strength\", \"type\": float, \"default\": 0.75},\n        {\"key\": \"mask\", \"type\": str, \"default\": None},\n        {\"key\": \"include_init_images\", \"type\": bool, \"default\": False, \"exclude\" : True},\n        {\"key\": \"script_name\", \"type\": str, \"default\": None},\n        {\"key\": \"script_args\", \"type\": list, \"default\": []},\n        {\"key\": \"send_images\", \"type\": bool, \"default\": True},\n        {\"key\": \"save_images\", \"type\": bool, \"default\": False},\n        {\"key\": \"alwayson_scripts\", \"type\": dict, \"default\": {}},\n        {\"key\": \"force_task_id\", \"type\": str, \"default\": None},\n        {\"key\": \"infotext\", \"type\": str, \"default\": None},\n    ]\n).generate_model()\n\nclass TextToImageResponse(BaseModel):\n    images: list[str] = Field(default=None, title=\"Image\", description=\"The generated image in base64 format.\")\n    parameters: dict\n    info: str\n\nclass ImageToImageResponse(BaseModel):\n    images: list[str] = Field(default=None, title=\"Image\", description=\"The generated image in base64 format.\")\n    parameters: dict\n    info: str\n\nclass ExtrasBaseRequest(BaseModel):\n    resize_mode: Literal[0, 1] = Field(default=0, title=\"Resize Mode\", description=\"Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.\")\n    show_extras_results: bool = Field(default=True, title=\"Show results\", description=\"Should the backend return 
the generated image?\")\n    gfpgan_visibility: float = Field(default=0, title=\"GFPGAN Visibility\", ge=0, le=1, allow_inf_nan=False, description=\"Sets the visibility of GFPGAN, values should be between 0 and 1.\")\n    codeformer_visibility: float = Field(default=0, title=\"CodeFormer Visibility\", ge=0, le=1, allow_inf_nan=False, description=\"Sets the visibility of CodeFormer, values should be between 0 and 1.\")\n    codeformer_weight: float = Field(default=0, title=\"CodeFormer Weight\", ge=0, le=1, allow_inf_nan=False, description=\"Sets the weight of CodeFormer, values should be between 0 and 1.\")\n    upscaling_resize: float = Field(default=2, title=\"Upscaling Factor\", gt=0, description=\"By how much to upscale the image, only used when resize_mode=0.\")\n    upscaling_resize_w: int = Field(default=512, title=\"Target Width\", ge=1, description=\"Target width for the upscaler to hit. Only used when resize_mode=1.\")\n    upscaling_resize_h: int = Field(default=512, title=\"Target Height\", ge=1, description=\"Target height for the upscaler to hit. 
Only used when resize_mode=1.\")\n    upscaling_crop: bool = Field(default=True, title=\"Crop to fit\", description=\"Should the upscaler crop the image to fit in the chosen size?\")\n    upscaler_1: str = Field(default=\"None\", title=\"Main upscaler\", description=f\"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}\")\n    upscaler_2: str = Field(default=\"None\", title=\"Secondary upscaler\", description=f\"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}\")\n    extras_upscaler_2_visibility: float = Field(default=0, title=\"Secondary upscaler visibility\", ge=0, le=1, allow_inf_nan=False, description=\"Sets the visibility of secondary upscaler, values should be between 0 and 1.\")\n    upscale_first: bool = Field(default=False, title=\"Upscale first\", description=\"Should the upscaler run before restoring faces?\")\n\nclass ExtraBaseResponse(BaseModel):\n    html_info: str = Field(title=\"HTML info\", description=\"A series of HTML tags containing the process info.\")\n\nclass ExtrasSingleImageRequest(ExtrasBaseRequest):\n    image: str = Field(default=\"\", title=\"Image\", description=\"Image to work on, must be a Base64 string containing the image's data.\")\n\nclass ExtrasSingleImageResponse(ExtraBaseResponse):\n    image: str = Field(default=None, title=\"Image\", description=\"The generated image in base64 format.\")\n\nclass FileData(BaseModel):\n    data: str = Field(title=\"File data\", description=\"Base64 representation of the file\")\n    name: str = Field(title=\"File name\")\n\nclass ExtrasBatchImagesRequest(ExtrasBaseRequest):\n    imageList: list[FileData] = Field(title=\"Images\", description=\"List of images to work on. 
Must be Base64 strings\")\n\nclass ExtrasBatchImagesResponse(ExtraBaseResponse):\n    images: list[str] = Field(title=\"Images\", description=\"The generated images in base64 format.\")\n\nclass PNGInfoRequest(BaseModel):\n    image: str = Field(title=\"Image\", description=\"The base64 encoded PNG image\")\n\nclass PNGInfoResponse(BaseModel):\n    info: str = Field(title=\"Image info\", description=\"A string with the parameters used to generate the image\")\n    items: dict = Field(title=\"Items\", description=\"A dictionary containing all the other fields the image had\")\n    parameters: dict = Field(title=\"Parameters\", description=\"A dictionary with parsed generation info fields\")\n\nclass ProgressRequest(BaseModel):\n    skip_current_image: bool = Field(default=False, title=\"Skip current image\", description=\"Skip current image serialization\")\n\nclass ProgressResponse(BaseModel):\n    progress: float = Field(title=\"Progress\", description=\"The progress with a range of 0 to 1\")\n    eta_relative: float = Field(title=\"ETA in secs\")\n    state: dict = Field(title=\"State\", description=\"The current state snapshot\")\n    current_image: str = Field(default=None, title=\"Current image\", description=\"The current image in base64 format. 
opts.show_progress_every_n_steps is required for this to work.\")\n    textinfo: str = Field(default=None, title=\"Info text\", description=\"Info text used by WebUI.\")\n\nclass InterrogateRequest(BaseModel):\n    image: str = Field(default=\"\", title=\"Image\", description=\"Image to work on, must be a Base64 string containing the image's data.\")\n    model: str = Field(default=\"clip\", title=\"Model\", description=\"The interrogate model used.\")\n\nclass InterrogateResponse(BaseModel):\n    caption: str = Field(default=None, title=\"Caption\", description=\"The generated caption for the image.\")\n\nclass TrainResponse(BaseModel):\n    info: str = Field(title=\"Train info\", description=\"Response string from train embedding or hypernetwork task.\")\n\nclass CreateResponse(BaseModel):\n    info: str = Field(title=\"Create info\", description=\"Response string from create embedding or hypernetwork task.\")\n\nfields = {}\nfor key, metadata in opts.data_labels.items():\n    value = opts.data.get(key)\n    optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any\n\n    if metadata is not None:\n        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})\n    else:\n        fields.update({key: (Optional[optType], Field())})\n\nOptionsModel = create_model(\"Options\", **fields)\n\nflags = {}\n_options = vars(parser)['_option_string_actions']\nfor key in _options:\n    if(_options[key].dest != 'help'):\n        flag = _options[key]\n        _type = str\n        if _options[key].default is not None:\n            _type = type(_options[key].default)\n        flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})\n\nFlagsModel = create_model(\"Flags\", **flags)\n\nclass SamplerItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    aliases: list[str] = Field(title=\"Aliases\")\n    options: dict[str, str] = 
Field(title=\"Options\")\n\nclass SchedulerItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    label: str = Field(title=\"Label\")\n    aliases: Optional[list[str]] = Field(title=\"Aliases\")\n    default_rho: Optional[float] = Field(title=\"Default Rho\")\n    need_inner_model: Optional[bool] = Field(title=\"Needs Inner Model\")\n\nclass UpscalerItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    model_name: Optional[str] = Field(title=\"Model Name\")\n    model_path: Optional[str] = Field(title=\"Path\")\n    model_url: Optional[str] = Field(title=\"URL\")\n    scale: Optional[float] = Field(title=\"Scale\")\n\nclass LatentUpscalerModeItem(BaseModel):\n    name: str = Field(title=\"Name\")\n\nclass SDModelItem(BaseModel):\n    title: str = Field(title=\"Title\")\n    model_name: str = Field(title=\"Model Name\")\n    hash: Optional[str] = Field(title=\"Short hash\")\n    sha256: Optional[str] = Field(title=\"sha256 hash\")\n    filename: str = Field(title=\"Filename\")\n    config: Optional[str] = Field(title=\"Config file\")\n\nclass SDVaeItem(BaseModel):\n    model_name: str = Field(title=\"Model Name\")\n    filename: str = Field(title=\"Filename\")\n\nclass HypernetworkItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    path: Optional[str] = Field(title=\"Path\")\n\nclass FaceRestorerItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    cmd_dir: Optional[str] = Field(title=\"Path\")\n\nclass RealesrganItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    path: Optional[str] = Field(title=\"Path\")\n    scale: Optional[int] = Field(title=\"Scale\")\n\nclass PromptStyleItem(BaseModel):\n    name: str = Field(title=\"Name\")\n    prompt: Optional[str] = Field(title=\"Prompt\")\n    negative_prompt: Optional[str] = Field(title=\"Negative Prompt\")\n\n\nclass EmbeddingItem(BaseModel):\n    step: Optional[int] = Field(title=\"Step\", description=\"The number of steps that were used to train this embedding, if available\")\n  
  sd_checkpoint: Optional[str] = Field(title=\"SD Checkpoint\", description=\"The hash of the checkpoint this embedding was trained on, if available\")\n    sd_checkpoint_name: Optional[str] = Field(title=\"SD Checkpoint Name\", description=\"The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead\")\n    shape: int = Field(title=\"Shape\", description=\"The length of each individual vector in the embedding\")\n    vectors: int = Field(title=\"Vectors\", description=\"The number of vectors in the embedding\")\n\nclass EmbeddingsResponse(BaseModel):\n    loaded: dict[str, EmbeddingItem] = Field(title=\"Loaded\", description=\"Embeddings loaded for the current model\")\n    skipped: dict[str, EmbeddingItem] = Field(title=\"Skipped\", description=\"Embeddings skipped for the current model (likely due to architecture incompatibility)\")\n\nclass MemoryResponse(BaseModel):\n    ram: dict = Field(title=\"RAM\", description=\"System memory stats\")\n    cuda: dict = Field(title=\"CUDA\", description=\"nVidia CUDA memory stats\")\n\n\nclass ScriptsList(BaseModel):\n    txt2img: list = Field(default=None, title=\"Txt2img\", description=\"Titles of scripts (txt2img)\")\n    img2img: list = Field(default=None, title=\"Img2img\", description=\"Titles of scripts (img2img)\")\n\n\nclass ScriptArg(BaseModel):\n    label: str = Field(default=None, title=\"Label\", description=\"Name of the argument in UI\")\n    value: Optional[Any] = Field(default=None, title=\"Value\", description=\"Default value of the argument\")\n    minimum: Optional[Any] = Field(default=None, title=\"Minimum\", description=\"Minimum allowed value for the argumentin UI\")\n    maximum: Optional[Any] = Field(default=None, title=\"Minimum\", description=\"Maximum allowed value for the argumentin UI\")\n    step: Optional[Any] = Field(default=None, title=\"Minimum\", description=\"Step 
for changing value of the argumentin UI\")\n    choices: Optional[list[str]] = Field(default=None, title=\"Choices\", description=\"Possible values for the argument\")\n\n\nclass ScriptInfo(BaseModel):\n    name: str = Field(default=None, title=\"Name\", description=\"Script name\")\n    is_alwayson: bool = Field(default=None, title=\"IsAlwayson\", description=\"Flag specifying whether this script is an alwayson script\")\n    is_img2img: bool = Field(default=None, title=\"IsImg2img\", description=\"Flag specifying whether this script is an img2img script\")\n    args: list[ScriptArg] = Field(title=\"Arguments\", description=\"List of script's arguments\")\n\nclass ExtensionItem(BaseModel):\n    name: str = Field(title=\"Name\", description=\"Extension name\")\n    remote: str = Field(title=\"Remote\", description=\"Extension Repository URL\")\n    branch: str = Field(title=\"Branch\", description=\"Extension Repository Branch\")\n    commit_hash: str = Field(title=\"Commit Hash\", description=\"Extension Repository Commit Hash\")\n    version: str = Field(title=\"Version\", description=\"Extension Version\")\n    commit_date: str = Field(title=\"Commit Date\", description=\"Extension Repository Commit Date\")\n    enabled: bool = Field(title=\"Enabled\", description=\"Flag specifying whether this extension is enabled\")\n"
  },
  {
    "path": "modules/cache.py",
    "content": "import json\r\nimport os\r\nimport os.path\r\nimport threading\r\n\r\nimport diskcache\r\nimport tqdm\r\n\r\nfrom modules.paths import data_path, script_path\r\n\r\ncache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, \"cache.json\"))\r\ncache_dir = os.environ.get('SD_WEBUI_CACHE_DIR', os.path.join(data_path, \"cache\"))\r\ncaches = {}\r\ncache_lock = threading.Lock()\r\n\r\n\r\ndef dump_cache():\r\n    \"\"\"old function for dumping cache to disk; does nothing since diskcache.\"\"\"\r\n\r\n    pass\r\n\r\n\r\ndef make_cache(subsection: str) -> diskcache.Cache:\r\n    return diskcache.Cache(\r\n        os.path.join(cache_dir, subsection),\r\n        size_limit=2**32,  # 4 GB, culling oldest first\r\n        disk_min_file_size=2**18,  # keep up to 256KB in Sqlite\r\n    )\r\n\r\n\r\ndef convert_old_cached_data():\r\n    try:\r\n        with open(cache_filename, \"r\", encoding=\"utf8\") as file:\r\n            data = json.load(file)\r\n    except FileNotFoundError:\r\n        return\r\n    except Exception:\r\n        os.replace(cache_filename, os.path.join(script_path, \"tmp\", \"cache.json\"))\r\n        print('[ERROR] issue occurred while trying to read cache.json; old cache has been moved to tmp/cache.json')\r\n        return\r\n\r\n    total_count = sum(len(keyvalues) for keyvalues in data.values())\r\n\r\n    with tqdm.tqdm(total=total_count, desc=\"converting cache\") as progress:\r\n        for subsection, keyvalues in data.items():\r\n            cache_obj = caches.get(subsection)\r\n            if cache_obj is None:\r\n                cache_obj = make_cache(subsection)\r\n                caches[subsection] = cache_obj\r\n\r\n            for key, value in keyvalues.items():\r\n                cache_obj[key] = value\r\n                progress.update(1)\r\n\r\n\r\ndef cache(subsection):\r\n    \"\"\"\r\n    Retrieves or initializes a cache for a specific subsection.\r\n\r\n    Parameters:\r\n        subsection (str): 
The subsection identifier for the cache.\r\n\r\n    Returns:\r\n        diskcache.Cache: The cache data for the specified subsection.\r\n    \"\"\"\r\n\r\n    cache_obj = caches.get(subsection)\r\n    if not cache_obj:\r\n        with cache_lock:\r\n            if not os.path.exists(cache_dir) and os.path.isfile(cache_filename):\r\n                convert_old_cached_data()\r\n\r\n            cache_obj = caches.get(subsection)\r\n            if not cache_obj:\r\n                cache_obj = make_cache(subsection)\r\n                caches[subsection] = cache_obj\r\n\r\n    return cache_obj\r\n\r\n\r\ndef cached_data_for_file(subsection, title, filename, func):\r\n    \"\"\"\r\n    Retrieves or generates data for a specific file, using a caching mechanism.\r\n\r\n    Parameters:\r\n        subsection (str): The subsection of the cache to use.\r\n        title (str): The title of the data entry in the subsection of the cache.\r\n        filename (str): The path to the file to be checked for modifications.\r\n        func (callable): A function that generates the data if it is not available in the cache.\r\n\r\n    Returns:\r\n        dict or None: The cached or generated data, or None if data generation fails.\r\n\r\n    The `cached_data_for_file` function implements a caching mechanism for data stored in files.\r\n    It checks if the data associated with the given `title` is present in the cache and compares the\r\n    modification time of the file with the cached modification time. If the file has been modified,\r\n    the cache is considered invalid and the data is regenerated using the provided `func`.\r\n    Otherwise, the cached data is returned.\r\n\r\n    If the data generation fails, None is returned to indicate the failure. 
Otherwise, the generated\r\n    or cached data is returned as a dictionary.\r\n    \"\"\"\r\n\r\n    existing_cache = cache(subsection)\r\n    ondisk_mtime = os.path.getmtime(filename)\r\n\r\n    entry = existing_cache.get(title)\r\n    if entry:\r\n        cached_mtime = entry.get(\"mtime\", 0)\r\n        if ondisk_mtime > cached_mtime:\r\n            entry = None\r\n\r\n    if not entry or 'value' not in entry:\r\n        value = func()\r\n        if value is None:\r\n            return None\r\n\r\n        entry = {'mtime': ondisk_mtime, 'value': value}\r\n        existing_cache[title] = entry\r\n\r\n        dump_cache()\r\n\r\n    return entry['value']\r\n"
  },
  {
    "path": "modules/call_queue.py",
    "content": "import os.path\r\nfrom functools import wraps\r\nimport html\r\nimport time\r\n\r\nfrom modules import shared, progress, errors, devices, fifo_lock, profiling\r\n\r\nqueue_lock = fifo_lock.FIFOLock()\r\n\r\n\r\ndef wrap_queued_call(func):\r\n    def f(*args, **kwargs):\r\n        with queue_lock:\r\n            res = func(*args, **kwargs)\r\n\r\n        return res\r\n\r\n    return f\r\n\r\n\r\ndef wrap_gradio_gpu_call(func, extra_outputs=None):\r\n    @wraps(func)\r\n    def f(*args, **kwargs):\r\n\r\n        # if the first argument is a string that says \"task(...)\", it is treated as a job id\r\n        if args and type(args[0]) == str and args[0].startswith(\"task(\") and args[0].endswith(\")\"):\r\n            id_task = args[0]\r\n            progress.add_task_to_queue(id_task)\r\n        else:\r\n            id_task = None\r\n\r\n        with queue_lock:\r\n            shared.state.begin(job=id_task)\r\n            progress.start_task(id_task)\r\n\r\n            try:\r\n                res = func(*args, **kwargs)\r\n                progress.record_results(id_task, res)\r\n            finally:\r\n                progress.finish_task(id_task)\r\n\r\n            shared.state.end()\r\n\r\n        return res\r\n\r\n    return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True)\r\n\r\n\r\ndef wrap_gradio_call(func, extra_outputs=None, add_stats=False):\r\n    @wraps(func)\r\n    def f(*args, **kwargs):\r\n        try:\r\n            res = func(*args, **kwargs)\r\n        finally:\r\n            shared.state.skipped = False\r\n            shared.state.interrupted = False\r\n            shared.state.stopping_generation = False\r\n            shared.state.job_count = 0\r\n            shared.state.job = \"\"\r\n        return res\r\n\r\n    return wrap_gradio_call_no_job(f, extra_outputs, add_stats)\r\n\r\n\r\ndef wrap_gradio_call_no_job(func, extra_outputs=None, add_stats=False):\r\n    @wraps(func)\r\n    def f(*args, 
extra_outputs_array=extra_outputs, **kwargs):\r\n        run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats\r\n        if run_memmon:\r\n            shared.mem_mon.monitor()\r\n        t = time.perf_counter()\r\n\r\n        try:\r\n            res = list(func(*args, **kwargs))\r\n        except Exception as e:\r\n            # When printing out our debug argument list,\r\n            # do not print out more than a 100 KB of text\r\n            max_debug_str_len = 131072\r\n            message = \"Error completing request\"\r\n            arg_str = f\"Arguments: {args} {kwargs}\"[:max_debug_str_len]\r\n            if len(arg_str) > max_debug_str_len:\r\n                arg_str += f\" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)\"\r\n            errors.report(f\"{message}\\n{arg_str}\", exc_info=True)\r\n\r\n            if extra_outputs_array is None:\r\n                extra_outputs_array = [None, '']\r\n\r\n            error_message = f'{type(e).__name__}: {e}'\r\n            res = extra_outputs_array + [f\"<div class='error'>{html.escape(error_message)}</div>\"]\r\n\r\n        devices.torch_gc()\r\n\r\n        if not add_stats:\r\n            return tuple(res)\r\n\r\n        elapsed = time.perf_counter() - t\r\n        elapsed_m = int(elapsed // 60)\r\n        elapsed_s = elapsed % 60\r\n        elapsed_text = f\"{elapsed_s:.1f} sec.\"\r\n        if elapsed_m > 0:\r\n            elapsed_text = f\"{elapsed_m} min. 
\"+elapsed_text\r\n\r\n        if run_memmon:\r\n            mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()}\r\n            active_peak = mem_stats['active_peak']\r\n            reserved_peak = mem_stats['reserved_peak']\r\n            sys_peak = mem_stats['system_peak']\r\n            sys_total = mem_stats['total']\r\n            sys_pct = sys_peak/max(sys_total, 1) * 100\r\n\r\n            toltip_a = \"Active: peak amount of video memory used during generation (excluding cached data)\"\r\n            toltip_r = \"Reserved: total amount of video memory allocated by the Torch library \"\r\n            toltip_sys = \"System: peak amount of video memory allocated by all running programs, out of total capacity\"\r\n\r\n            text_a = f\"<abbr title='{toltip_a}'>A</abbr>: <span class='measurement'>{active_peak/1024:.2f} GB</span>\"\r\n            text_r = f\"<abbr title='{toltip_r}'>R</abbr>: <span class='measurement'>{reserved_peak/1024:.2f} GB</span>\"\r\n            text_sys = f\"<abbr title='{toltip_sys}'>Sys</abbr>: <span class='measurement'>{sys_peak/1024:.1f}/{sys_total/1024:g} GB</span> ({sys_pct:.1f}%)\"\r\n\r\n            vram_html = f\"<p class='vram'>{text_a}, <wbr>{text_r}, <wbr>{text_sys}</p>\"\r\n        else:\r\n            vram_html = ''\r\n\r\n        if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):\r\n            profiling_html = f\"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>\"\r\n        else:\r\n            profiling_html = ''\r\n\r\n        # last item is always HTML\r\n        res[-1] += f\"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>\"\r\n\r\n        return tuple(res)\r\n\r\n    return f\r\n\r\n"
  },
  {
    "path": "modules/cmd_args.py",
    "content": "import argparse\r\nimport json\r\nimport os\r\nfrom modules.paths_internal import normalized_filepath, models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file  # noqa: F401\r\n\r\nparser = argparse.ArgumentParser()\r\n\r\nparser.add_argument(\"-f\", action='store_true', help=argparse.SUPPRESS)  # allows running as root; implemented outside of webui\r\nparser.add_argument(\"--update-all-extensions\", action='store_true', help=\"launch.py argument: download updates for all extensions when starting the program\")\r\nparser.add_argument(\"--skip-python-version-check\", action='store_true', help=\"launch.py argument: do not check python version\")\r\nparser.add_argument(\"--skip-torch-cuda-test\", action='store_true', help=\"launch.py argument: do not check if CUDA is able to work properly\")\r\nparser.add_argument(\"--reinstall-xformers\", action='store_true', help=\"launch.py argument: install the appropriate version of xformers even if you have some version already installed\")\r\nparser.add_argument(\"--reinstall-torch\", action='store_true', help=\"launch.py argument: install the appropriate version of torch even if you have some version already installed\")\r\nparser.add_argument(\"--update-check\", action='store_true', help=\"launch.py argument: check for updates at startup\")\r\nparser.add_argument(\"--test-server\", action='store_true', help=\"launch.py argument: configure server for testing\")\r\nparser.add_argument(\"--log-startup\", action='store_true', help=\"launch.py argument: print a detailed log of what's happening at startup\")\r\nparser.add_argument(\"--skip-prepare-environment\", action='store_true', help=\"launch.py argument: skip all environment preparation\")\r\nparser.add_argument(\"--skip-install\", action='store_true', help=\"launch.py argument: skip installation of packages\")\r\nparser.add_argument(\"--dump-sysinfo\", action='store_true', help=\"launch.py argument: dump 
limited sysinfo file (without information about extensions, options) to disk and quit\")\r\nparser.add_argument(\"--loglevel\", type=str, help=\"log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG\", default=None)\r\nparser.add_argument(\"--do-not-download-clip\", action='store_true', help=\"do not download CLIP model even if it's not included in the checkpoint\")\r\nparser.add_argument(\"--data-dir\", type=normalized_filepath, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help=\"base path where all user data is stored\")\r\nparser.add_argument(\"--models-dir\", type=normalized_filepath, default=None, help=\"base path where models are stored; overrides --data-dir\")\r\nparser.add_argument(\"--config\", type=normalized_filepath, default=sd_default_config, help=\"path to config which constructs model\",)\r\nparser.add_argument(\"--ckpt\", type=normalized_filepath, default=sd_model_file, help=\"path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded\",)\r\nparser.add_argument(\"--ckpt-dir\", type=normalized_filepath, default=None, help=\"Path to directory with stable diffusion checkpoints\")\r\nparser.add_argument(\"--vae-dir\", type=normalized_filepath, default=None, help=\"Path to directory with VAE files\")\r\nparser.add_argument(\"--gfpgan-dir\", type=normalized_filepath, help=\"GFPGAN directory\", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))\r\nparser.add_argument(\"--gfpgan-model\", type=normalized_filepath, help=\"GFPGAN model file name\", default=None)\r\nparser.add_argument(\"--no-half\", action='store_true', help=\"do not switch the model to 16-bit floats\")\r\nparser.add_argument(\"--no-half-vae\", action='store_true', help=\"do not switch the VAE model to 16-bit floats\")\r\nparser.add_argument(\"--no-progressbar-hiding\", action='store_true', help=\"do not hide progressbar in gradio UI (we hide it because it slows down ML if 
you have hardware acceleration in browser)\")\r\nparser.add_argument(\"--max-batch-count\", type=int, default=16, help=\"does not do anything\")\r\nparser.add_argument(\"--embeddings-dir\", type=normalized_filepath, default=os.path.join(data_path, 'embeddings'), help=\"embeddings directory for textual inversion (default: embeddings)\")\r\nparser.add_argument(\"--textual-inversion-templates-dir\", type=normalized_filepath, default=os.path.join(script_path, 'textual_inversion_templates'), help=\"directory with textual inversion templates\")\r\nparser.add_argument(\"--hypernetwork-dir\", type=normalized_filepath, default=os.path.join(models_path, 'hypernetworks'), help=\"hypernetwork directory\")\r\nparser.add_argument(\"--localizations-dir\", type=normalized_filepath, default=os.path.join(script_path, 'localizations'), help=\"localizations directory\")\r\nparser.add_argument(\"--allow-code\", action='store_true', help=\"allow custom script execution from webui\")\r\nparser.add_argument(\"--medvram\", action='store_true', help=\"enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage\")\r\nparser.add_argument(\"--medvram-sdxl\", action='store_true', help=\"enable --medvram optimization just for SDXL models\")\r\nparser.add_argument(\"--lowvram\", action='store_true', help=\"enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage\")\r\nparser.add_argument(\"--lowram\", action='store_true', help=\"load stable diffusion checkpoint weights to VRAM instead of RAM\")\r\nparser.add_argument(\"--always-batch-cond-uncond\", action='store_true', help=\"does not do anything\")\r\nparser.add_argument(\"--unload-gfpgan\", action='store_true', help=\"does not do anything.\")\r\nparser.add_argument(\"--precision\", type=str, help=\"evaluate at this precision\", choices=[\"full\", \"half\", \"autocast\"], default=\"autocast\")\r\nparser.add_argument(\"--upcast-sampling\", action='store_true', help=\"upcast 
sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.\")\r\nparser.add_argument(\"--share\", action='store_true', help=\"use share=True for gradio and make the UI accessible through their site\")\r\nparser.add_argument(\"--ngrok\", type=str, help=\"ngrok authtoken, alternative to gradio --share\", default=None)\r\nparser.add_argument(\"--ngrok-region\", type=str, help=\"does not do anything.\", default=\"\")\r\nparser.add_argument(\"--ngrok-options\", type=json.loads, help='The options to pass to ngrok in JSON format, e.g.: \\'{\"authtoken_from_env\":true, \"basic_auth\":\"user:password\", \"oauth_provider\":\"google\", \"oauth_allow_emails\":\"user@asdf.com\"}\\'', default=dict())\r\nparser.add_argument(\"--enable-insecure-extension-access\", action='store_true', help=\"enable extensions tab regardless of other options\")\r\nparser.add_argument(\"--codeformer-models-path\", type=normalized_filepath, help=\"Path to directory with codeformer model file(s).\", default=os.path.join(models_path, 'Codeformer'))\r\nparser.add_argument(\"--gfpgan-models-path\", type=normalized_filepath, help=\"Path to directory with GFPGAN model file(s).\", default=os.path.join(models_path, 'GFPGAN'))\r\nparser.add_argument(\"--esrgan-models-path\", type=normalized_filepath, help=\"Path to directory with ESRGAN model file(s).\", default=os.path.join(models_path, 'ESRGAN'))\r\nparser.add_argument(\"--bsrgan-models-path\", type=normalized_filepath, help=\"Path to directory with BSRGAN model file(s).\", default=os.path.join(models_path, 'BSRGAN'))\r\nparser.add_argument(\"--realesrgan-models-path\", type=normalized_filepath, help=\"Path to directory with RealESRGAN model file(s).\", default=os.path.join(models_path, 'RealESRGAN'))\r\nparser.add_argument(\"--dat-models-path\", type=normalized_filepath, help=\"Path to directory with DAT model file(s).\", default=os.path.join(models_path, 
'DAT'))\r\nparser.add_argument(\"--clip-models-path\", type=normalized_filepath, help=\"Path to directory with CLIP model file(s).\", default=None)\r\nparser.add_argument(\"--xformers\", action='store_true', help=\"enable xformers for cross attention layers\")\r\nparser.add_argument(\"--force-enable-xformers\", action='store_true', help=\"enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work\")\r\nparser.add_argument(\"--xformers-flash-attention\", action='store_true', help=\"enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)\")\r\nparser.add_argument(\"--deepdanbooru\", action='store_true', help=\"does not do anything\")\r\nparser.add_argument(\"--opt-split-attention\", action='store_true', help=\"prefer Doggettx's cross-attention layer optimization for automatic choice of optimization\")\r\nparser.add_argument(\"--opt-sub-quad-attention\", action='store_true', help=\"prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization\")\r\nparser.add_argument(\"--sub-quad-q-chunk-size\", type=int, help=\"query chunk size for the sub-quadratic cross-attention layer optimization to use\", default=1024)\r\nparser.add_argument(\"--sub-quad-kv-chunk-size\", type=int, help=\"kv chunk size for the sub-quadratic cross-attention layer optimization to use\", default=None)\r\nparser.add_argument(\"--sub-quad-chunk-threshold\", type=int, help=\"the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking\", default=None)\r\nparser.add_argument(\"--opt-split-attention-invokeai\", action='store_true', help=\"prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization\")\r\nparser.add_argument(\"--opt-split-attention-v1\", action='store_true', help=\"prefer older version of split attention optimization for automatic 
choice of optimization\")\r\nparser.add_argument(\"--opt-sdp-attention\", action='store_true', help=\"prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*\")\r\nparser.add_argument(\"--opt-sdp-no-mem-attention\", action='store_true', help=\"prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*\")\r\nparser.add_argument(\"--disable-opt-split-attention\", action='store_true', help=\"prefer no cross-attention layer optimization for automatic choice of optimization\")\r\nparser.add_argument(\"--disable-nan-check\", action='store_true', help=\"do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI\")\r\nparser.add_argument(\"--use-cpu\", nargs='+', help=\"use CPU as torch device for specified modules\", default=[], type=str.lower)\r\nparser.add_argument(\"--use-ipex\", action=\"store_true\", help=\"use Intel XPU as torch device\")\r\nparser.add_argument(\"--disable-model-loading-ram-optimization\", action='store_true', help=\"disable an optimization that reduces RAM use when loading a model\")\r\nparser.add_argument(\"--listen\", action='store_true', help=\"launch gradio with 0.0.0.0 as server name, allowing to respond to network requests\")\r\nparser.add_argument(\"--port\", type=int, help=\"launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available\", default=None)\r\nparser.add_argument(\"--show-negative-prompt\", action='store_true', help=\"does not do anything\", default=False)\r\nparser.add_argument(\"--ui-config-file\", type=str, help=\"filename to use for ui configuration\", default=os.path.join(data_path, 'ui-config.json'))\r\nparser.add_argument(\"--hide-ui-dir-config\", action='store_true', help=\"hide directory configuration from webui\", 
default=False)\r\nparser.add_argument(\"--freeze-settings\", action='store_true', help=\"disable editing of all settings globally\", default=False)\r\nparser.add_argument(\"--freeze-settings-in-sections\", type=str, help='disable editing settings in specific sections of the settings page by specifying a comma-delimited list such like \"saving-images,upscaling\". The list of setting names can be found in the modules/shared_options.py file', default=None)\r\nparser.add_argument(\"--freeze-specific-settings\", type=str, help='disable editing of individual settings by specifying a comma-delimited list like \"samples_save,samples_format\". The list of setting names can be found in the config.json file', default=None)\r\nparser.add_argument(\"--ui-settings-file\", type=str, help=\"filename to use for ui settings\", default=os.path.join(data_path, 'config.json'))\r\nparser.add_argument(\"--gradio-debug\",  action='store_true', help=\"launch gradio with --debug option\")\r\nparser.add_argument(\"--gradio-auth\", type=str, help='set gradio authentication like \"username:password\"; or comma-delimit multiple like \"u1:p1,u2:p2,u3:p3\"', default=None)\r\nparser.add_argument(\"--gradio-auth-path\", type=normalized_filepath, help='set gradio authentication file path ex. 
\"/path/to/auth/file\" same auth format as --gradio-auth', default=None)\r\nparser.add_argument(\"--gradio-img2img-tool\", type=str, help='does not do anything')\r\nparser.add_argument(\"--gradio-inpaint-tool\", type=str, help=\"does not do anything\")\r\nparser.add_argument(\"--gradio-allowed-path\", action='append', help=\"add path to gradio's allowed_paths, make it possible to serve files from it\", default=[data_path])\r\nparser.add_argument(\"--opt-channelslast\", action='store_true', help=\"change memory type for stable diffusion to channels last\")\r\nparser.add_argument(\"--styles-file\", type=str, action='append', help=\"path or wildcard path of styles files, allow multiple entries.\", default=[])\r\nparser.add_argument(\"--autolaunch\", action='store_true', help=\"open the webui URL in the system's default browser upon launch\", default=False)\r\nparser.add_argument(\"--theme\", type=str, help=\"launches the UI with light or dark theme\", default=None)\r\nparser.add_argument(\"--use-textbox-seed\", action='store_true', help=\"use textbox for seeds in UI (no up/down, but possible to input long seeds)\", default=False)\r\nparser.add_argument(\"--disable-console-progressbars\", action='store_true', help=\"do not output progressbars to console\", default=False)\r\nparser.add_argument(\"--enable-console-prompts\", action='store_true', help=\"does not do anything\", default=False)  # Legacy compatibility, use as default value shared.opts.enable_console_prompts\r\nparser.add_argument('--vae-path', type=normalized_filepath, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)\r\nparser.add_argument(\"--disable-safe-unpickle\", action='store_true', help=\"disable checking pytorch models for malicious code\", default=False)\r\nparser.add_argument(\"--api\", action='store_true', help=\"use api=True to launch the API together with the webui (use --nowebui instead for only the 
API)\")\r\nparser.add_argument(\"--api-auth\", type=str, help='Set authentication for API like \"username:password\"; or comma-delimit multiple like \"u1:p1,u2:p2,u3:p3\"', default=None)\r\nparser.add_argument(\"--api-log\", action='store_true', help=\"use api-log=True to enable logging of all API requests\")\r\nparser.add_argument(\"--nowebui\", action='store_true', help=\"use api=True to launch the API instead of the webui\")\r\nparser.add_argument(\"--ui-debug-mode\", action='store_true', help=\"Don't load model to quickly launch UI\")\r\nparser.add_argument(\"--device-id\", type=str, help=\"Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)\", default=None)\r\nparser.add_argument(\"--administrator\", action='store_true', help=\"Administrator rights\", default=False)\r\nparser.add_argument(\"--cors-allow-origins\", type=str, help=\"Allowed CORS origin(s) in the form of a comma-separated list (no spaces)\", default=None)\r\nparser.add_argument(\"--cors-allow-origins-regex\", type=str, help=\"Allowed CORS origin(s) in the form of a single regular expression\", default=None)\r\nparser.add_argument(\"--tls-keyfile\", type=str, help=\"Partially enables TLS, requires --tls-certfile to fully function\", default=None)\r\nparser.add_argument(\"--tls-certfile\", type=str, help=\"Partially enables TLS, requires --tls-keyfile to fully function\", default=None)\r\nparser.add_argument(\"--disable-tls-verify\", action=\"store_false\", help=\"When passed, enables the use of self-signed certificates.\", default=None)\r\nparser.add_argument(\"--server-name\", type=str, help=\"Sets hostname of server\", default=None)\r\nparser.add_argument(\"--gradio-queue\", action='store_true', help=\"does not do anything\", default=True)\r\nparser.add_argument(\"--no-gradio-queue\", action='store_true', help=\"Disables gradio queue; causes the webpage to use http requests instead of websockets; was the default in earlier 
versions\")\r\nparser.add_argument(\"--skip-version-check\", action='store_true', help=\"Do not check versions of torch and xformers\")\r\nparser.add_argument(\"--no-hashing\", action='store_true', help=\"disable sha256 hashing of checkpoints to help loading performance\", default=False)\r\nparser.add_argument(\"--no-download-sd-model\", action='store_true', help=\"don't download SD1.5 model even if no model is found in --ckpt-dir\", default=False)\r\nparser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')\r\nparser.add_argument('--add-stop-route', action='store_true', help='does not do anything')\r\nparser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')\r\nparser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')\r\nparser.add_argument(\"--disable-all-extensions\", action='store_true', help=\"prevent all extensions from running regardless of any other settings\", default=False)\r\nparser.add_argument(\"--disable-extra-extensions\", action='store_true', help=\"prevent all extensions except built-in from running regardless of any other settings\", default=False)\r\nparser.add_argument(\"--skip-load-model-at-start\", action='store_true', help=\"if load a model at web start, only take effect when --nowebui\")\r\nparser.add_argument(\"--unix-filenames-sanitization\", action='store_true', help=\"allow any symbols except '/' in filenames. May conflict with your browser and file system\")\r\nparser.add_argument(\"--filenames-max-length\", type=int, default=128, help='maximal length of filenames of saved images. If you override it, it can conflict with your file system')\r\nparser.add_argument(\"--no-prompt-history\", action='store_true', help=\"disable read prompt from last generation feature; settings this argument will not create '--data_path/params.txt' file\")\r\n"
  },
  {
    "path": "modules/codeformer_model.py",
    "content": "from __future__ import annotations\r\n\r\nimport logging\r\n\r\nimport torch\r\n\r\nfrom modules import (\r\n    devices,\r\n    errors,\r\n    face_restoration,\r\n    face_restoration_utils,\r\n    modelloader,\r\n    shared,\r\n)\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\nmodel_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'\r\nmodel_download_name = 'codeformer-v0.1.0.pth'\r\n\r\n# used by e.g. postprocessing_codeformer.py\r\ncodeformer: face_restoration.FaceRestoration | None = None\r\n\r\n\r\nclass FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):\r\n    def name(self):\r\n        return \"CodeFormer\"\r\n\r\n    def load_net(self) -> torch.Module:\r\n        for model_path in modelloader.load_models(\r\n            model_path=self.model_path,\r\n            model_url=model_url,\r\n            command_path=self.model_path,\r\n            download_name=model_download_name,\r\n            ext_filter=['.pth'],\r\n        ):\r\n            return modelloader.load_spandrel_model(\r\n                model_path,\r\n                device=devices.device_codeformer,\r\n                expected_architecture='CodeFormer',\r\n            ).model\r\n        raise ValueError(\"No codeformer model found\")\r\n\r\n    def get_device(self):\r\n        return devices.device_codeformer\r\n\r\n    def restore(self, np_image, w: float | None = None):\r\n        if w is None:\r\n            w = getattr(shared.opts, \"code_former_weight\", 0.5)\r\n\r\n        def restore_face(cropped_face_t):\r\n            assert self.net is not None\r\n            return self.net(cropped_face_t, weight=w, adain=True)[0]\r\n\r\n        return self.restore_with_helper(np_image, restore_face)\r\n\r\n\r\ndef setup_model(dirname: str) -> None:\r\n    global codeformer\r\n    try:\r\n        codeformer = FaceRestorerCodeFormer(dirname)\r\n        shared.face_restorers.append(codeformer)\r\n    except Exception:\r\n        
errors.report(\"Error setting up CodeFormer\", exc_info=True)\r\n"
  },
  {
    "path": "modules/config_states.py",
    "content": "\"\"\"\nSupports saving and restoring webui and extensions from a known working set of commits\n\"\"\"\n\nimport os\nimport json\nimport tqdm\n\nfrom datetime import datetime\nimport git\n\nfrom modules import shared, extensions, errors\nfrom modules.paths_internal import script_path, config_states_dir\n\nall_config_states = {}\n\n\ndef list_config_states():\n    global all_config_states\n\n    all_config_states.clear()\n    os.makedirs(config_states_dir, exist_ok=True)\n\n    config_states = []\n    for filename in os.listdir(config_states_dir):\n        if filename.endswith(\".json\"):\n            path = os.path.join(config_states_dir, filename)\n            try:\n                with open(path, \"r\", encoding=\"utf-8\") as f:\n                    j = json.load(f)\n                    assert \"created_at\" in j, '\"created_at\" does not exist'\n                    j[\"filepath\"] = path\n                    config_states.append(j)\n            except Exception as e:\n                print(f'[ERROR]: Config states {path}, {e}')\n\n    config_states = sorted(config_states, key=lambda cs: cs[\"created_at\"], reverse=True)\n\n    for cs in config_states:\n        timestamp = datetime.fromtimestamp(cs[\"created_at\"]).strftime('%Y-%m-%d %H:%M:%S')\n        name = cs.get(\"name\", \"Config\")\n        full_name = f\"{name}: {timestamp}\"\n        all_config_states[full_name] = cs\n\n    return all_config_states\n\n\ndef get_webui_config():\n    webui_repo = None\n\n    try:\n        if os.path.exists(os.path.join(script_path, \".git\")):\n            webui_repo = git.Repo(script_path)\n    except Exception:\n        errors.report(f\"Error reading webui git info from {script_path}\", exc_info=True)\n\n    webui_remote = None\n    webui_commit_hash = None\n    webui_commit_date = None\n    webui_branch = None\n    if webui_repo and not webui_repo.bare:\n        try:\n            webui_remote = next(webui_repo.remote().urls, None)\n            head = 
webui_repo.head.commit\n            webui_commit_date = webui_repo.head.commit.committed_date\n            webui_commit_hash = head.hexsha\n            webui_branch = webui_repo.active_branch.name\n\n        except Exception:\n            webui_remote = None\n\n    return {\n        \"remote\": webui_remote,\n        \"commit_hash\": webui_commit_hash,\n        \"commit_date\": webui_commit_date,\n        \"branch\": webui_branch,\n    }\n\n\ndef get_extension_config():\n    ext_config = {}\n\n    for ext in extensions.extensions:\n        ext.read_info_from_repo()\n\n        entry = {\n            \"name\": ext.name,\n            \"path\": ext.path,\n            \"enabled\": ext.enabled,\n            \"is_builtin\": ext.is_builtin,\n            \"remote\": ext.remote,\n            \"commit_hash\": ext.commit_hash,\n            \"commit_date\": ext.commit_date,\n            \"branch\": ext.branch,\n            \"have_info_from_repo\": ext.have_info_from_repo\n        }\n\n        ext_config[ext.name] = entry\n\n    return ext_config\n\n\ndef get_config():\n    creation_time = datetime.now().timestamp()\n    webui_config = get_webui_config()\n    ext_config = get_extension_config()\n\n    return {\n        \"created_at\": creation_time,\n        \"webui\": webui_config,\n        \"extensions\": ext_config\n    }\n\n\ndef restore_webui_config(config):\n    print(\"* Restoring webui state...\")\n\n    if \"webui\" not in config:\n        print(\"Error: No webui data saved to config\")\n        return\n\n    webui_config = config[\"webui\"]\n\n    if \"commit_hash\" not in webui_config:\n        print(\"Error: No commit saved to webui config\")\n        return\n\n    webui_commit_hash = webui_config.get(\"commit_hash\", None)\n    webui_repo = None\n\n    try:\n        if os.path.exists(os.path.join(script_path, \".git\")):\n            webui_repo = git.Repo(script_path)\n    except Exception:\n        errors.report(f\"Error reading webui git info from {script_path}\", 
exc_info=True)\n        return\n\n    try:\n        webui_repo.git.fetch(all=True)\n        webui_repo.git.reset(webui_commit_hash, hard=True)\n        print(f\"* Restored webui to commit {webui_commit_hash}.\")\n    except Exception:\n        errors.report(f\"Error restoring webui to commit{webui_commit_hash}\")\n\n\ndef restore_extension_config(config):\n    print(\"* Restoring extension state...\")\n\n    if \"extensions\" not in config:\n        print(\"Error: No extension data saved to config\")\n        return\n\n    ext_config = config[\"extensions\"]\n\n    results = []\n    disabled = []\n\n    for ext in tqdm.tqdm(extensions.extensions):\n        if ext.is_builtin:\n            continue\n\n        ext.read_info_from_repo()\n        current_commit = ext.commit_hash\n\n        if ext.name not in ext_config:\n            ext.disabled = True\n            disabled.append(ext.name)\n            results.append((ext, current_commit[:8], False, \"Saved extension state not found in config, marking as disabled\"))\n            continue\n\n        entry = ext_config[ext.name]\n\n        if \"commit_hash\" in entry and entry[\"commit_hash\"]:\n            try:\n                ext.fetch_and_reset_hard(entry[\"commit_hash\"])\n                ext.read_info_from_repo()\n                if current_commit != entry[\"commit_hash\"]:\n                    results.append((ext, current_commit[:8], True, entry[\"commit_hash\"][:8]))\n            except Exception as ex:\n                results.append((ext, current_commit[:8], False, ex))\n        else:\n            results.append((ext, current_commit[:8], False, \"No commit hash found in config\"))\n\n        if not entry.get(\"enabled\", False):\n            ext.disabled = True\n            disabled.append(ext.name)\n        else:\n            ext.disabled = False\n\n    shared.opts.disabled_extensions = disabled\n    shared.opts.save(shared.config_filename)\n\n    print(\"* Finished restoring extensions. 
Results:\")\n    for ext, prev_commit, success, result in results:\n        if success:\n            print(f\"  + {ext.name}: {prev_commit} -> {result}\")\n        else:\n            print(f\"  ! {ext.name}: FAILURE ({result})\")\n"
  },
  {
    "path": "modules/dat_model.py",
    "content": "import os\n\nfrom modules import modelloader, errors\nfrom modules.shared import cmd_opts, opts\nfrom modules.upscaler import Upscaler, UpscalerData\nfrom modules.upscaler_utils import upscale_with_model\n\n\nclass UpscalerDAT(Upscaler):\n    def __init__(self, user_path):\n        self.name = \"DAT\"\n        self.user_path = user_path\n        self.scalers = []\n        super().__init__()\n\n        for file in self.find_models(ext_filter=[\".pt\", \".pth\"]):\n            name = modelloader.friendly_name(file)\n            scaler_data = UpscalerData(name, file, upscaler=self, scale=None)\n            self.scalers.append(scaler_data)\n\n        for model in get_dat_models(self):\n            if model.name in opts.dat_enabled_models:\n                self.scalers.append(model)\n\n    def do_upscale(self, img, path):\n        try:\n            info = self.load_model(path)\n        except Exception:\n            errors.report(f\"Unable to load DAT model {path}\", exc_info=True)\n            return img\n\n        model_descriptor = modelloader.load_spandrel_model(\n            info.local_data_path,\n            device=self.device,\n            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),\n            expected_architecture=\"DAT\",\n        )\n        return upscale_with_model(\n            model_descriptor,\n            img,\n            tile_size=opts.DAT_tile,\n            tile_overlap=opts.DAT_tile_overlap,\n        )\n\n    def load_model(self, path):\n        for scaler in self.scalers:\n            if scaler.data_path == path:\n                if scaler.local_data_path.startswith(\"http\"):\n                    scaler.local_data_path = modelloader.load_file_from_url(\n                        scaler.data_path,\n                        model_dir=self.model_download_path,\n                    )\n                if not os.path.exists(scaler.local_data_path):\n                    raise FileNotFoundError(f\"DAT data missing: 
{scaler.local_data_path}\")\n                return scaler\n        raise ValueError(f\"Unable to find model info: {path}\")\n\n\ndef get_dat_models(scaler):\n    return [\n        UpscalerData(\n            name=\"DAT x2\",\n            path=\"https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth\",\n            scale=2,\n            upscaler=scaler,\n        ),\n        UpscalerData(\n            name=\"DAT x3\",\n            path=\"https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth\",\n            scale=3,\n            upscaler=scaler,\n        ),\n        UpscalerData(\n            name=\"DAT x4\",\n            path=\"https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth\",\n            scale=4,\n            upscaler=scaler,\n        ),\n    ]\n"
  },
  {
    "path": "modules/deepbooru.py",
    "content": "import os\nimport re\n\nimport torch\nimport numpy as np\n\nfrom modules import modelloader, paths, deepbooru_model, devices, images, shared\n\nre_special = re.compile(r'([\\\\()])')\n\n\nclass DeepDanbooru:\n    def __init__(self):\n        self.model = None\n\n    def load(self):\n        if self.model is not None:\n            return\n\n        files = modelloader.load_models(\n            model_path=os.path.join(paths.models_path, \"torch_deepdanbooru\"),\n            model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt',\n            ext_filter=[\".pt\"],\n            download_name='model-resnet_custom_v3.pt',\n        )\n\n        self.model = deepbooru_model.DeepDanbooruModel()\n        self.model.load_state_dict(torch.load(files[0], map_location=\"cpu\"))\n\n        self.model.eval()\n        self.model.to(devices.cpu, devices.dtype)\n\n    def start(self):\n        self.load()\n        self.model.to(devices.device)\n\n    def stop(self):\n        if not shared.opts.interrogate_keep_models_in_memory:\n            self.model.to(devices.cpu)\n            devices.torch_gc()\n\n    def tag(self, pil_image):\n        self.start()\n        res = self.tag_multi(pil_image)\n        self.stop()\n\n        return res\n\n    def tag_multi(self, pil_image, force_disable_ranks=False):\n        threshold = shared.opts.interrogate_deepbooru_score_threshold\n        use_spaces = shared.opts.deepbooru_use_spaces\n        use_escape = shared.opts.deepbooru_escape\n        alpha_sort = shared.opts.deepbooru_sort_alpha\n        include_ranks = shared.opts.interrogate_return_ranks and not force_disable_ranks\n\n        pic = images.resize_image(2, pil_image.convert(\"RGB\"), 512, 512)\n        a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255\n\n        with torch.no_grad(), devices.autocast():\n            x = torch.from_numpy(a).to(devices.device, devices.dtype)\n            y = 
self.model(x)[0].detach().cpu().numpy()\n\n        probability_dict = {}\n\n        for tag, probability in zip(self.model.tags, y):\n            if probability < threshold:\n                continue\n\n            if tag.startswith(\"rating:\"):\n                continue\n\n            probability_dict[tag] = probability\n\n        if alpha_sort:\n            tags = sorted(probability_dict)\n        else:\n            tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]\n\n        res = []\n\n        filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(\",\")}\n\n        for tag in [x for x in tags if x not in filtertags]:\n            probability = probability_dict[tag]\n            tag_outformat = tag\n            if use_spaces:\n                tag_outformat = tag_outformat.replace('_', ' ')\n            if use_escape:\n                tag_outformat = re.sub(re_special, r'\\\\\\1', tag_outformat)\n            if include_ranks:\n                tag_outformat = f\"({tag_outformat}:{probability:.3f})\"\n\n            res.append(tag_outformat)\n\n        return \", \".join(res)\n\n\nmodel = DeepDanbooru()\n"
  },
  {
    "path": "modules/deepbooru_model.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nfrom modules import devices\r\n\r\n# see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more\r\n\r\n\r\nclass DeepDanbooruModel(nn.Module):\r\n    def __init__(self):\r\n        super(DeepDanbooruModel, self).__init__()\r\n\r\n        self.tags = []\r\n\r\n        self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))\r\n        self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))\r\n        self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)\r\n        self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)\r\n        self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)\r\n        self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)\r\n        self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)\r\n        self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)\r\n        self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)\r\n        self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)\r\n        self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)\r\n        self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)\r\n        self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))\r\n        self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)\r\n        self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))\r\n        self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), 
in_channels=128, out_channels=128)\r\n        self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)\r\n        self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)\r\n        self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)\r\n        self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))\r\n        self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, 
out_channels=256)\r\n        self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))\r\n        self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, 
out_channels=256)\r\n        self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        
self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))\r\n        self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))\r\n        
self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        
self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        
self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)\r\n        self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)\r\n        self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)\r\n        self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))\r\n        self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)\r\n        self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))\r\n        self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)\r\n        self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)\r\n        self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, 
out_channels=512)\r\n        self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)\r\n        self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)\r\n        self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)\r\n        self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)\r\n        self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))\r\n        self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)\r\n        self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))\r\n        self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)\r\n        self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)\r\n        self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)\r\n        self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)\r\n        self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)\r\n        self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)\r\n        self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)\r\n        self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)\r\n\r\n    def forward(self, *inputs):\r\n        t_358, = inputs\r\n        t_359 = t_358.permute(*[0, 3, 1, 2])\r\n        t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)\r\n        t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)\r\n        t_361 = F.relu(t_360)\r\n        t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))\r\n        t_362 = self.n_MaxPool_0(t_361)\r\n        t_363 = self.n_Conv_1(t_362)\r\n        t_364 = 
self.n_Conv_2(t_362)\r\n        t_365 = F.relu(t_364)\r\n        t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)\r\n        t_366 = self.n_Conv_3(t_365_padded)\r\n        t_367 = F.relu(t_366)\r\n        t_368 = self.n_Conv_4(t_367)\r\n        t_369 = torch.add(t_368, t_363)\r\n        t_370 = F.relu(t_369)\r\n        t_371 = self.n_Conv_5(t_370)\r\n        t_372 = F.relu(t_371)\r\n        t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)\r\n        t_373 = self.n_Conv_6(t_372_padded)\r\n        t_374 = F.relu(t_373)\r\n        t_375 = self.n_Conv_7(t_374)\r\n        t_376 = torch.add(t_375, t_370)\r\n        t_377 = F.relu(t_376)\r\n        t_378 = self.n_Conv_8(t_377)\r\n        t_379 = F.relu(t_378)\r\n        t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)\r\n        t_380 = self.n_Conv_9(t_379_padded)\r\n        t_381 = F.relu(t_380)\r\n        t_382 = self.n_Conv_10(t_381)\r\n        t_383 = torch.add(t_382, t_377)\r\n        t_384 = F.relu(t_383)\r\n        t_385 = self.n_Conv_11(t_384)\r\n        t_386 = self.n_Conv_12(t_384)\r\n        t_387 = F.relu(t_386)\r\n        t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)\r\n        t_388 = self.n_Conv_13(t_387_padded)\r\n        t_389 = F.relu(t_388)\r\n        t_390 = self.n_Conv_14(t_389)\r\n        t_391 = torch.add(t_390, t_385)\r\n        t_392 = F.relu(t_391)\r\n        t_393 = self.n_Conv_15(t_392)\r\n        t_394 = F.relu(t_393)\r\n        t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)\r\n        t_395 = self.n_Conv_16(t_394_padded)\r\n        t_396 = F.relu(t_395)\r\n        t_397 = self.n_Conv_17(t_396)\r\n        t_398 = torch.add(t_397, t_392)\r\n        t_399 = F.relu(t_398)\r\n        t_400 = self.n_Conv_18(t_399)\r\n        t_401 = F.relu(t_400)\r\n        t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)\r\n        t_402 = self.n_Conv_19(t_401_padded)\r\n        t_403 = F.relu(t_402)\r\n        t_404 = self.n_Conv_20(t_403)\r\n        t_405 = torch.add(t_404, t_399)\r\n        
t_406 = F.relu(t_405)\r\n        t_407 = self.n_Conv_21(t_406)\r\n        t_408 = F.relu(t_407)\r\n        t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)\r\n        t_409 = self.n_Conv_22(t_408_padded)\r\n        t_410 = F.relu(t_409)\r\n        t_411 = self.n_Conv_23(t_410)\r\n        t_412 = torch.add(t_411, t_406)\r\n        t_413 = F.relu(t_412)\r\n        t_414 = self.n_Conv_24(t_413)\r\n        t_415 = F.relu(t_414)\r\n        t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)\r\n        t_416 = self.n_Conv_25(t_415_padded)\r\n        t_417 = F.relu(t_416)\r\n        t_418 = self.n_Conv_26(t_417)\r\n        t_419 = torch.add(t_418, t_413)\r\n        t_420 = F.relu(t_419)\r\n        t_421 = self.n_Conv_27(t_420)\r\n        t_422 = F.relu(t_421)\r\n        t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)\r\n        t_423 = self.n_Conv_28(t_422_padded)\r\n        t_424 = F.relu(t_423)\r\n        t_425 = self.n_Conv_29(t_424)\r\n        t_426 = torch.add(t_425, t_420)\r\n        t_427 = F.relu(t_426)\r\n        t_428 = self.n_Conv_30(t_427)\r\n        t_429 = F.relu(t_428)\r\n        t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)\r\n        t_430 = self.n_Conv_31(t_429_padded)\r\n        t_431 = F.relu(t_430)\r\n        t_432 = self.n_Conv_32(t_431)\r\n        t_433 = torch.add(t_432, t_427)\r\n        t_434 = F.relu(t_433)\r\n        t_435 = self.n_Conv_33(t_434)\r\n        t_436 = F.relu(t_435)\r\n        t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)\r\n        t_437 = self.n_Conv_34(t_436_padded)\r\n        t_438 = F.relu(t_437)\r\n        t_439 = self.n_Conv_35(t_438)\r\n        t_440 = torch.add(t_439, t_434)\r\n        t_441 = F.relu(t_440)\r\n        t_442 = self.n_Conv_36(t_441)\r\n        t_443 = self.n_Conv_37(t_441)\r\n        t_444 = F.relu(t_443)\r\n        t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)\r\n        t_445 = self.n_Conv_38(t_444_padded)\r\n        t_446 = F.relu(t_445)\r\n        t_447 = self.n_Conv_39(t_446)\r\n      
  t_448 = torch.add(t_447, t_442)\r\n        t_449 = F.relu(t_448)\r\n        t_450 = self.n_Conv_40(t_449)\r\n        t_451 = F.relu(t_450)\r\n        t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)\r\n        t_452 = self.n_Conv_41(t_451_padded)\r\n        t_453 = F.relu(t_452)\r\n        t_454 = self.n_Conv_42(t_453)\r\n        t_455 = torch.add(t_454, t_449)\r\n        t_456 = F.relu(t_455)\r\n        t_457 = self.n_Conv_43(t_456)\r\n        t_458 = F.relu(t_457)\r\n        t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)\r\n        t_459 = self.n_Conv_44(t_458_padded)\r\n        t_460 = F.relu(t_459)\r\n        t_461 = self.n_Conv_45(t_460)\r\n        t_462 = torch.add(t_461, t_456)\r\n        t_463 = F.relu(t_462)\r\n        t_464 = self.n_Conv_46(t_463)\r\n        t_465 = F.relu(t_464)\r\n        t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)\r\n        t_466 = self.n_Conv_47(t_465_padded)\r\n        t_467 = F.relu(t_466)\r\n        t_468 = self.n_Conv_48(t_467)\r\n        t_469 = torch.add(t_468, t_463)\r\n        t_470 = F.relu(t_469)\r\n        t_471 = self.n_Conv_49(t_470)\r\n        t_472 = F.relu(t_471)\r\n        t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)\r\n        t_473 = self.n_Conv_50(t_472_padded)\r\n        t_474 = F.relu(t_473)\r\n        t_475 = self.n_Conv_51(t_474)\r\n        t_476 = torch.add(t_475, t_470)\r\n        t_477 = F.relu(t_476)\r\n        t_478 = self.n_Conv_52(t_477)\r\n        t_479 = F.relu(t_478)\r\n        t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)\r\n        t_480 = self.n_Conv_53(t_479_padded)\r\n        t_481 = F.relu(t_480)\r\n        t_482 = self.n_Conv_54(t_481)\r\n        t_483 = torch.add(t_482, t_477)\r\n        t_484 = F.relu(t_483)\r\n        t_485 = self.n_Conv_55(t_484)\r\n        t_486 = F.relu(t_485)\r\n        t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)\r\n        t_487 = self.n_Conv_56(t_486_padded)\r\n        t_488 = F.relu(t_487)\r\n        t_489 = self.n_Conv_57(t_488)\r\n  
      t_490 = torch.add(t_489, t_484)\r\n        t_491 = F.relu(t_490)\r\n        t_492 = self.n_Conv_58(t_491)\r\n        t_493 = F.relu(t_492)\r\n        t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)\r\n        t_494 = self.n_Conv_59(t_493_padded)\r\n        t_495 = F.relu(t_494)\r\n        t_496 = self.n_Conv_60(t_495)\r\n        t_497 = torch.add(t_496, t_491)\r\n        t_498 = F.relu(t_497)\r\n        t_499 = self.n_Conv_61(t_498)\r\n        t_500 = F.relu(t_499)\r\n        t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)\r\n        t_501 = self.n_Conv_62(t_500_padded)\r\n        t_502 = F.relu(t_501)\r\n        t_503 = self.n_Conv_63(t_502)\r\n        t_504 = torch.add(t_503, t_498)\r\n        t_505 = F.relu(t_504)\r\n        t_506 = self.n_Conv_64(t_505)\r\n        t_507 = F.relu(t_506)\r\n        t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)\r\n        t_508 = self.n_Conv_65(t_507_padded)\r\n        t_509 = F.relu(t_508)\r\n        t_510 = self.n_Conv_66(t_509)\r\n        t_511 = torch.add(t_510, t_505)\r\n        t_512 = F.relu(t_511)\r\n        t_513 = self.n_Conv_67(t_512)\r\n        t_514 = F.relu(t_513)\r\n        t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)\r\n        t_515 = self.n_Conv_68(t_514_padded)\r\n        t_516 = F.relu(t_515)\r\n        t_517 = self.n_Conv_69(t_516)\r\n        t_518 = torch.add(t_517, t_512)\r\n        t_519 = F.relu(t_518)\r\n        t_520 = self.n_Conv_70(t_519)\r\n        t_521 = F.relu(t_520)\r\n        t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)\r\n        t_522 = self.n_Conv_71(t_521_padded)\r\n        t_523 = F.relu(t_522)\r\n        t_524 = self.n_Conv_72(t_523)\r\n        t_525 = torch.add(t_524, t_519)\r\n        t_526 = F.relu(t_525)\r\n        t_527 = self.n_Conv_73(t_526)\r\n        t_528 = F.relu(t_527)\r\n        t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)\r\n        t_529 = self.n_Conv_74(t_528_padded)\r\n        t_530 = F.relu(t_529)\r\n        t_531 = 
self.n_Conv_75(t_530)\r\n        t_532 = torch.add(t_531, t_526)\r\n        t_533 = F.relu(t_532)\r\n        t_534 = self.n_Conv_76(t_533)\r\n        t_535 = F.relu(t_534)\r\n        t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)\r\n        t_536 = self.n_Conv_77(t_535_padded)\r\n        t_537 = F.relu(t_536)\r\n        t_538 = self.n_Conv_78(t_537)\r\n        t_539 = torch.add(t_538, t_533)\r\n        t_540 = F.relu(t_539)\r\n        t_541 = self.n_Conv_79(t_540)\r\n        t_542 = F.relu(t_541)\r\n        t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)\r\n        t_543 = self.n_Conv_80(t_542_padded)\r\n        t_544 = F.relu(t_543)\r\n        t_545 = self.n_Conv_81(t_544)\r\n        t_546 = torch.add(t_545, t_540)\r\n        t_547 = F.relu(t_546)\r\n        t_548 = self.n_Conv_82(t_547)\r\n        t_549 = F.relu(t_548)\r\n        t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)\r\n        t_550 = self.n_Conv_83(t_549_padded)\r\n        t_551 = F.relu(t_550)\r\n        t_552 = self.n_Conv_84(t_551)\r\n        t_553 = torch.add(t_552, t_547)\r\n        t_554 = F.relu(t_553)\r\n        t_555 = self.n_Conv_85(t_554)\r\n        t_556 = F.relu(t_555)\r\n        t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)\r\n        t_557 = self.n_Conv_86(t_556_padded)\r\n        t_558 = F.relu(t_557)\r\n        t_559 = self.n_Conv_87(t_558)\r\n        t_560 = torch.add(t_559, t_554)\r\n        t_561 = F.relu(t_560)\r\n        t_562 = self.n_Conv_88(t_561)\r\n        t_563 = F.relu(t_562)\r\n        t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)\r\n        t_564 = self.n_Conv_89(t_563_padded)\r\n        t_565 = F.relu(t_564)\r\n        t_566 = self.n_Conv_90(t_565)\r\n        t_567 = torch.add(t_566, t_561)\r\n        t_568 = F.relu(t_567)\r\n        t_569 = self.n_Conv_91(t_568)\r\n        t_570 = F.relu(t_569)\r\n        t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)\r\n        t_571 = self.n_Conv_92(t_570_padded)\r\n        t_572 = F.relu(t_571)\r\n        
t_573 = self.n_Conv_93(t_572)\r\n        t_574 = torch.add(t_573, t_568)\r\n        t_575 = F.relu(t_574)\r\n        t_576 = self.n_Conv_94(t_575)\r\n        t_577 = F.relu(t_576)\r\n        t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)\r\n        t_578 = self.n_Conv_95(t_577_padded)\r\n        t_579 = F.relu(t_578)\r\n        t_580 = self.n_Conv_96(t_579)\r\n        t_581 = torch.add(t_580, t_575)\r\n        t_582 = F.relu(t_581)\r\n        t_583 = self.n_Conv_97(t_582)\r\n        t_584 = F.relu(t_583)\r\n        t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)\r\n        t_585 = self.n_Conv_98(t_584_padded)\r\n        t_586 = F.relu(t_585)\r\n        t_587 = self.n_Conv_99(t_586)\r\n        t_588 = self.n_Conv_100(t_582)\r\n        t_589 = torch.add(t_587, t_588)\r\n        t_590 = F.relu(t_589)\r\n        t_591 = self.n_Conv_101(t_590)\r\n        t_592 = F.relu(t_591)\r\n        t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)\r\n        t_593 = self.n_Conv_102(t_592_padded)\r\n        t_594 = F.relu(t_593)\r\n        t_595 = self.n_Conv_103(t_594)\r\n        t_596 = torch.add(t_595, t_590)\r\n        t_597 = F.relu(t_596)\r\n        t_598 = self.n_Conv_104(t_597)\r\n        t_599 = F.relu(t_598)\r\n        t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)\r\n        t_600 = self.n_Conv_105(t_599_padded)\r\n        t_601 = F.relu(t_600)\r\n        t_602 = self.n_Conv_106(t_601)\r\n        t_603 = torch.add(t_602, t_597)\r\n        t_604 = F.relu(t_603)\r\n        t_605 = self.n_Conv_107(t_604)\r\n        t_606 = F.relu(t_605)\r\n        t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)\r\n        t_607 = self.n_Conv_108(t_606_padded)\r\n        t_608 = F.relu(t_607)\r\n        t_609 = self.n_Conv_109(t_608)\r\n        t_610 = torch.add(t_609, t_604)\r\n        t_611 = F.relu(t_610)\r\n        t_612 = self.n_Conv_110(t_611)\r\n        t_613 = F.relu(t_612)\r\n        t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)\r\n        t_614 = 
self.n_Conv_111(t_613_padded)\r\n        t_615 = F.relu(t_614)\r\n        t_616 = self.n_Conv_112(t_615)\r\n        t_617 = torch.add(t_616, t_611)\r\n        t_618 = F.relu(t_617)\r\n        t_619 = self.n_Conv_113(t_618)\r\n        t_620 = F.relu(t_619)\r\n        t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)\r\n        t_621 = self.n_Conv_114(t_620_padded)\r\n        t_622 = F.relu(t_621)\r\n        t_623 = self.n_Conv_115(t_622)\r\n        t_624 = torch.add(t_623, t_618)\r\n        t_625 = F.relu(t_624)\r\n        t_626 = self.n_Conv_116(t_625)\r\n        t_627 = F.relu(t_626)\r\n        t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)\r\n        t_628 = self.n_Conv_117(t_627_padded)\r\n        t_629 = F.relu(t_628)\r\n        t_630 = self.n_Conv_118(t_629)\r\n        t_631 = torch.add(t_630, t_625)\r\n        t_632 = F.relu(t_631)\r\n        t_633 = self.n_Conv_119(t_632)\r\n        t_634 = F.relu(t_633)\r\n        t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)\r\n        t_635 = self.n_Conv_120(t_634_padded)\r\n        t_636 = F.relu(t_635)\r\n        t_637 = self.n_Conv_121(t_636)\r\n        t_638 = torch.add(t_637, t_632)\r\n        t_639 = F.relu(t_638)\r\n        t_640 = self.n_Conv_122(t_639)\r\n        t_641 = F.relu(t_640)\r\n        t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)\r\n        t_642 = self.n_Conv_123(t_641_padded)\r\n        t_643 = F.relu(t_642)\r\n        t_644 = self.n_Conv_124(t_643)\r\n        t_645 = torch.add(t_644, t_639)\r\n        t_646 = F.relu(t_645)\r\n        t_647 = self.n_Conv_125(t_646)\r\n        t_648 = F.relu(t_647)\r\n        t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)\r\n        t_649 = self.n_Conv_126(t_648_padded)\r\n        t_650 = F.relu(t_649)\r\n        t_651 = self.n_Conv_127(t_650)\r\n        t_652 = torch.add(t_651, t_646)\r\n        t_653 = F.relu(t_652)\r\n        t_654 = self.n_Conv_128(t_653)\r\n        t_655 = F.relu(t_654)\r\n        t_655_padded = F.pad(t_655, [1, 1, 1, 1], 
value=0)\r\n        t_656 = self.n_Conv_129(t_655_padded)\r\n        t_657 = F.relu(t_656)\r\n        t_658 = self.n_Conv_130(t_657)\r\n        t_659 = torch.add(t_658, t_653)\r\n        t_660 = F.relu(t_659)\r\n        t_661 = self.n_Conv_131(t_660)\r\n        t_662 = F.relu(t_661)\r\n        t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)\r\n        t_663 = self.n_Conv_132(t_662_padded)\r\n        t_664 = F.relu(t_663)\r\n        t_665 = self.n_Conv_133(t_664)\r\n        t_666 = torch.add(t_665, t_660)\r\n        t_667 = F.relu(t_666)\r\n        t_668 = self.n_Conv_134(t_667)\r\n        t_669 = F.relu(t_668)\r\n        t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)\r\n        t_670 = self.n_Conv_135(t_669_padded)\r\n        t_671 = F.relu(t_670)\r\n        t_672 = self.n_Conv_136(t_671)\r\n        t_673 = torch.add(t_672, t_667)\r\n        t_674 = F.relu(t_673)\r\n        t_675 = self.n_Conv_137(t_674)\r\n        t_676 = F.relu(t_675)\r\n        t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)\r\n        t_677 = self.n_Conv_138(t_676_padded)\r\n        t_678 = F.relu(t_677)\r\n        t_679 = self.n_Conv_139(t_678)\r\n        t_680 = torch.add(t_679, t_674)\r\n        t_681 = F.relu(t_680)\r\n        t_682 = self.n_Conv_140(t_681)\r\n        t_683 = F.relu(t_682)\r\n        t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)\r\n        t_684 = self.n_Conv_141(t_683_padded)\r\n        t_685 = F.relu(t_684)\r\n        t_686 = self.n_Conv_142(t_685)\r\n        t_687 = torch.add(t_686, t_681)\r\n        t_688 = F.relu(t_687)\r\n        t_689 = self.n_Conv_143(t_688)\r\n        t_690 = F.relu(t_689)\r\n        t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)\r\n        t_691 = self.n_Conv_144(t_690_padded)\r\n        t_692 = F.relu(t_691)\r\n        t_693 = self.n_Conv_145(t_692)\r\n        t_694 = torch.add(t_693, t_688)\r\n        t_695 = F.relu(t_694)\r\n        t_696 = self.n_Conv_146(t_695)\r\n        t_697 = F.relu(t_696)\r\n        t_697_padded = 
F.pad(t_697, [1, 1, 1, 1], value=0)\r\n        t_698 = self.n_Conv_147(t_697_padded)\r\n        t_699 = F.relu(t_698)\r\n        t_700 = self.n_Conv_148(t_699)\r\n        t_701 = torch.add(t_700, t_695)\r\n        t_702 = F.relu(t_701)\r\n        t_703 = self.n_Conv_149(t_702)\r\n        t_704 = F.relu(t_703)\r\n        t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)\r\n        t_705 = self.n_Conv_150(t_704_padded)\r\n        t_706 = F.relu(t_705)\r\n        t_707 = self.n_Conv_151(t_706)\r\n        t_708 = torch.add(t_707, t_702)\r\n        t_709 = F.relu(t_708)\r\n        t_710 = self.n_Conv_152(t_709)\r\n        t_711 = F.relu(t_710)\r\n        t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)\r\n        t_712 = self.n_Conv_153(t_711_padded)\r\n        t_713 = F.relu(t_712)\r\n        t_714 = self.n_Conv_154(t_713)\r\n        t_715 = torch.add(t_714, t_709)\r\n        t_716 = F.relu(t_715)\r\n        t_717 = self.n_Conv_155(t_716)\r\n        t_718 = F.relu(t_717)\r\n        t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)\r\n        t_719 = self.n_Conv_156(t_718_padded)\r\n        t_720 = F.relu(t_719)\r\n        t_721 = self.n_Conv_157(t_720)\r\n        t_722 = torch.add(t_721, t_716)\r\n        t_723 = F.relu(t_722)\r\n        t_724 = self.n_Conv_158(t_723)\r\n        t_725 = self.n_Conv_159(t_723)\r\n        t_726 = F.relu(t_725)\r\n        t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)\r\n        t_727 = self.n_Conv_160(t_726_padded)\r\n        t_728 = F.relu(t_727)\r\n        t_729 = self.n_Conv_161(t_728)\r\n        t_730 = torch.add(t_729, t_724)\r\n        t_731 = F.relu(t_730)\r\n        t_732 = self.n_Conv_162(t_731)\r\n        t_733 = F.relu(t_732)\r\n        t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)\r\n        t_734 = self.n_Conv_163(t_733_padded)\r\n        t_735 = F.relu(t_734)\r\n        t_736 = self.n_Conv_164(t_735)\r\n        t_737 = torch.add(t_736, t_731)\r\n        t_738 = F.relu(t_737)\r\n        t_739 = 
self.n_Conv_165(t_738)\r\n        t_740 = F.relu(t_739)\r\n        t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)\r\n        t_741 = self.n_Conv_166(t_740_padded)\r\n        t_742 = F.relu(t_741)\r\n        t_743 = self.n_Conv_167(t_742)\r\n        t_744 = torch.add(t_743, t_738)\r\n        t_745 = F.relu(t_744)\r\n        t_746 = self.n_Conv_168(t_745)\r\n        t_747 = self.n_Conv_169(t_745)\r\n        t_748 = F.relu(t_747)\r\n        t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)\r\n        t_749 = self.n_Conv_170(t_748_padded)\r\n        t_750 = F.relu(t_749)\r\n        t_751 = self.n_Conv_171(t_750)\r\n        t_752 = torch.add(t_751, t_746)\r\n        t_753 = F.relu(t_752)\r\n        t_754 = self.n_Conv_172(t_753)\r\n        t_755 = F.relu(t_754)\r\n        t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)\r\n        t_756 = self.n_Conv_173(t_755_padded)\r\n        t_757 = F.relu(t_756)\r\n        t_758 = self.n_Conv_174(t_757)\r\n        t_759 = torch.add(t_758, t_753)\r\n        t_760 = F.relu(t_759)\r\n        t_761 = self.n_Conv_175(t_760)\r\n        t_762 = F.relu(t_761)\r\n        t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)\r\n        t_763 = self.n_Conv_176(t_762_padded)\r\n        t_764 = F.relu(t_763)\r\n        t_765 = self.n_Conv_177(t_764)\r\n        t_766 = torch.add(t_765, t_760)\r\n        t_767 = F.relu(t_766)\r\n        t_768 = self.n_Conv_178(t_767)\r\n        t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])\r\n        t_770 = torch.squeeze(t_769, 3)\r\n        t_770 = torch.squeeze(t_770, 2)\r\n        t_771 = torch.sigmoid(t_770)\r\n        return t_771\r\n\r\n    def load_state_dict(self, state_dict, **kwargs):\r\n        self.tags = state_dict.get('tags', [])\r\n\r\n        super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})\r\n\r\n"
  },
  {
    "path": "modules/devices.py",
    "content": "import sys\nimport contextlib\nfrom functools import lru_cache\n\nimport torch\nfrom modules import errors, shared, npu_specific\n\nif sys.platform == \"darwin\":\n    from modules import mac_specific\n\nif shared.cmd_opts.use_ipex:\n    from modules import xpu_specific\n\n\ndef has_xpu() -> bool:\n    return shared.cmd_opts.use_ipex and xpu_specific.has_xpu\n\n\ndef has_mps() -> bool:\n    if sys.platform != \"darwin\":\n        return False\n    else:\n        return mac_specific.has_mps\n\n\ndef cuda_no_autocast(device_id=None) -> bool:\n    if device_id is None:\n        device_id = get_cuda_device_id()\n    return (\n        torch.cuda.get_device_capability(device_id) == (7, 5)\n        and torch.cuda.get_device_name(device_id).startswith(\"NVIDIA GeForce GTX 16\")\n    )\n\n\ndef get_cuda_device_id():\n    return (\n        int(shared.cmd_opts.device_id)\n        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()\n        else 0\n    ) or torch.cuda.current_device()\n\n\ndef get_cuda_device_string():\n    if shared.cmd_opts.device_id is not None:\n        return f\"cuda:{shared.cmd_opts.device_id}\"\n\n    return \"cuda\"\n\n\ndef get_optimal_device_name():\n    if torch.cuda.is_available():\n        return get_cuda_device_string()\n\n    if has_mps():\n        return \"mps\"\n\n    if has_xpu():\n        return xpu_specific.get_xpu_device_string()\n\n    if npu_specific.has_npu:\n        return npu_specific.get_npu_device_string()\n\n    return \"cpu\"\n\n\ndef get_optimal_device():\n    return torch.device(get_optimal_device_name())\n\n\ndef get_device_for(task):\n    if task in shared.cmd_opts.use_cpu or \"all\" in shared.cmd_opts.use_cpu:\n        return cpu\n\n    return get_optimal_device()\n\n\ndef torch_gc():\n\n    if torch.cuda.is_available():\n        with torch.cuda.device(get_cuda_device_string()):\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n    if has_mps():\n       
 mac_specific.torch_mps_gc()\n\n    if has_xpu():\n        xpu_specific.torch_xpu_gc()\n\n    if npu_specific.has_npu:\n        torch_npu_set_device()\n        npu_specific.torch_npu_gc()\n\n\ndef torch_npu_set_device():\n    # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue\n    if npu_specific.has_npu:\n        torch.npu.set_device(0)\n\n\ndef enable_tf32():\n    if torch.cuda.is_available():\n\n        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't\n        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407\n        if cuda_no_autocast():\n            torch.backends.cudnn.benchmark = True\n\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.allow_tf32 = True\n\n\nerrors.run(enable_tf32, \"Enabling TF32\")\n\ncpu: torch.device = torch.device(\"cpu\")\nfp8: bool = False\n# Force fp16 for all models in inference. 
No casting during inference.\n# This flag is controlled by \"--precision half\" command line arg.\nforce_fp16: bool = False\ndevice: torch.device = None\ndevice_interrogate: torch.device = None\ndevice_gfpgan: torch.device = None\ndevice_esrgan: torch.device = None\ndevice_codeformer: torch.device = None\ndtype: torch.dtype = torch.float16\ndtype_vae: torch.dtype = torch.float16\ndtype_unet: torch.dtype = torch.float16\ndtype_inference: torch.dtype = torch.float16\nunet_needs_upcast = False\n\n\ndef cond_cast_unet(input):\n    if force_fp16:\n        return input.to(torch.float16)\n    return input.to(dtype_unet) if unet_needs_upcast else input\n\n\ndef cond_cast_float(input):\n    return input.float() if unet_needs_upcast else input\n\n\nnv_rng = None\npatch_module_list = [\n    torch.nn.Linear,\n    torch.nn.Conv2d,\n    torch.nn.MultiheadAttention,\n    torch.nn.GroupNorm,\n    torch.nn.LayerNorm,\n]\n\n\ndef manual_cast_forward(target_dtype):\n    def forward_wrapper(self, *args, **kwargs):\n        if any(\n            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype\n            for arg in args\n        ):\n            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]\n            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}\n\n        org_dtype = target_dtype\n        for param in self.parameters():\n            if param.dtype != target_dtype:\n                org_dtype = param.dtype\n                break\n\n        if org_dtype != target_dtype:\n            self.to(target_dtype)\n        result = self.org_forward(*args, **kwargs)\n        if org_dtype != target_dtype:\n            self.to(org_dtype)\n\n        if target_dtype != dtype_inference:\n            if isinstance(result, tuple):\n                result = tuple(\n                    i.to(dtype_inference)\n                    if isinstance(i, torch.Tensor)\n                    else i\n               
     for i in result\n                )\n            elif isinstance(result, torch.Tensor):\n                result = result.to(dtype_inference)\n        return result\n    return forward_wrapper\n\n\n@contextlib.contextmanager\ndef manual_cast(target_dtype):\n    applied = False\n    for module_type in patch_module_list:\n        if hasattr(module_type, \"org_forward\"):\n            continue\n        applied = True\n        org_forward = module_type.forward\n        if module_type == torch.nn.MultiheadAttention:\n            module_type.forward = manual_cast_forward(torch.float32)\n        else:\n            module_type.forward = manual_cast_forward(target_dtype)\n        module_type.org_forward = org_forward\n    try:\n        yield None\n    finally:\n        if applied:\n            for module_type in patch_module_list:\n                if hasattr(module_type, \"org_forward\"):\n                    module_type.forward = module_type.org_forward\n                    delattr(module_type, \"org_forward\")\n\n\ndef autocast(disable=False):\n    if disable:\n        return contextlib.nullcontext()\n\n    if force_fp16:\n        # No casting during inference if force_fp16 is enabled.\n        # All tensor dtype conversion happens before inference.\n        return contextlib.nullcontext()\n\n    if fp8 and device==cpu:\n        return torch.autocast(\"cpu\", dtype=torch.bfloat16, enabled=True)\n\n    if fp8 and dtype_inference == torch.float32:\n        return manual_cast(dtype)\n\n    if dtype == torch.float32 or dtype_inference == torch.float32:\n        return contextlib.nullcontext()\n\n    if has_xpu() or has_mps() or cuda_no_autocast():\n        return manual_cast(dtype)\n\n    return torch.autocast(\"cuda\")\n\n\ndef without_autocast(disable=False):\n    return torch.autocast(\"cuda\", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()\n\n\nclass NansException(Exception):\n    pass\n\n\ndef test_for_nans(x, where):\n    
if shared.cmd_opts.disable_nan_check:\n        return\n\n    if not torch.isnan(x[(0, ) * len(x.shape)]):\n        return\n\n    if where == \"unet\":\n        message = \"A tensor with NaNs was produced in Unet.\"\n\n        if not shared.cmd_opts.no_half:\n            message += \" This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \\\"Upcast cross attention layer to float32\\\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this.\"\n\n    elif where == \"vae\":\n        message = \"A tensor with NaNs was produced in VAE.\"\n\n        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:\n            message += \" This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this.\"\n    else:\n        message = \"A tensor with NaNs was produced.\"\n\n    message += \" Use --disable-nan-check commandline argument to disable this check.\"\n\n    raise NansException(message)\n\n\n@lru_cache\ndef first_time_calculation():\n    \"\"\"\n    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and\n    spends about 2.7 seconds doing that, at least with NVidia.\n    \"\"\"\n\n    x = torch.zeros((1, 1)).to(device, dtype)\n    linear = torch.nn.Linear(1, 1).to(device, dtype)\n    linear(x)\n\n    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)\n    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)\n    conv2d(x)\n\n\ndef force_model_fp16():\n    \"\"\"\n    ldm and sgm has modules.diffusionmodules.util.GroupNorm32.forward, which\n    force conversion of input to float32. 
If force_fp16 is enabled, we need to\n    prevent this casting.\n    \"\"\"\n    assert force_fp16\n    import sgm.modules.diffusionmodules.util as sgm_util\n    import ldm.modules.diffusionmodules.util as ldm_util\n    sgm_util.GroupNorm32 = torch.nn.GroupNorm\n    ldm_util.GroupNorm32 = torch.nn.GroupNorm\n    print(\"ldm/sgm GroupNorm32 replaced with normal torch.nn.GroupNorm due to `--precision half`.\")\n"
  },
  {
    "path": "modules/errors.py",
    "content": "import sys\r\nimport textwrap\r\nimport traceback\r\n\r\n\r\nexception_records = []\r\n\r\n\r\ndef format_traceback(tb):\r\n    return [[f\"{x.filename}, line {x.lineno}, {x.name}\", x.line] for x in traceback.extract_tb(tb)]\r\n\r\n\r\ndef format_exception(e, tb):\r\n    return {\"exception\": str(e), \"traceback\": format_traceback(tb)}\r\n\r\n\r\ndef get_exceptions():\r\n    try:\r\n        return list(reversed(exception_records))\r\n    except Exception as e:\r\n        return str(e)\r\n\r\n\r\ndef record_exception():\r\n    _, e, tb = sys.exc_info()\r\n    if e is None:\r\n        return\r\n\r\n    if exception_records and exception_records[-1] == e:\r\n        return\r\n\r\n    exception_records.append(format_exception(e, tb))\r\n\r\n    if len(exception_records) > 5:\r\n        exception_records.pop(0)\r\n\r\n\r\ndef report(message: str, *, exc_info: bool = False) -> None:\r\n    \"\"\"\r\n    Print an error message to stderr, with optional traceback.\r\n    \"\"\"\r\n\r\n    record_exception()\r\n\r\n    for line in message.splitlines():\r\n        print(\"***\", line, file=sys.stderr)\r\n    if exc_info:\r\n        print(textwrap.indent(traceback.format_exc(), \"    \"), file=sys.stderr)\r\n        print(\"---\", file=sys.stderr)\r\n\r\n\r\ndef print_error_explanation(message):\r\n    record_exception()\r\n\r\n    lines = message.strip().split(\"\\n\")\r\n    max_len = max([len(x) for x in lines])\r\n\r\n    print('=' * max_len, file=sys.stderr)\r\n    for line in lines:\r\n        print(line, file=sys.stderr)\r\n    print('=' * max_len, file=sys.stderr)\r\n\r\n\r\ndef display(e: Exception, task, *, full_traceback=False):\r\n    record_exception()\r\n\r\n    print(f\"{task or 'error'}: {type(e).__name__}\", file=sys.stderr)\r\n    te = traceback.TracebackException.from_exception(e)\r\n    if full_traceback:\r\n        # include frames leading up to the try-catch block\r\n        te.stack = 
traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack)\r\n    print(*te.format(), sep=\"\", file=sys.stderr)\r\n\r\n    message = str(e)\r\n    if \"copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])\" in message:\r\n        print_error_explanation(\"\"\"\r\nThe most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file.\r\nSee https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.\r\n        \"\"\")\r\n\r\n\r\nalready_displayed = {}\r\n\r\n\r\ndef display_once(e: Exception, task):\r\n    record_exception()\r\n\r\n    if task in already_displayed:\r\n        return\r\n\r\n    display(e, task)\r\n\r\n    already_displayed[task] = 1\r\n\r\n\r\ndef run(code, task):\r\n    try:\r\n        code()\r\n    except Exception as e:\r\n        display(task, e)\r\n\r\n\r\ndef check_versions():\r\n    from packaging import version\r\n    from modules import shared\r\n\r\n    import torch\r\n    import gradio\r\n\r\n    expected_torch_version = \"2.1.2\"\r\n    expected_xformers_version = \"0.0.23.post1\"\r\n    expected_gradio_version = \"3.41.2\"\r\n\r\n    if version.parse(torch.__version__) < version.parse(expected_torch_version):\r\n        print_error_explanation(f\"\"\"\r\nYou are running torch {torch.__version__}.\r\nThe program is tested to work with torch {expected_torch_version}.\r\nTo reinstall the desired version, run with commandline flag --reinstall-torch.\r\nBeware that this will cause a lot of large files to be downloaded, as well as\r\nthere are reports of issues with training tab on the latest version.\r\n\r\nUse --skip-version-check commandline argument to disable this check.\r\n        \"\"\".strip())\r\n\r\n    if shared.xformers_available:\r\n        import xformers\r\n\r\n        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):\r\n  
          print_error_explanation(f\"\"\"\r\nYou are running xformers {xformers.__version__}.\r\nThe program is tested to work with xformers {expected_xformers_version}.\r\nTo reinstall the desired version, run with commandline flag --reinstall-xformers.\r\n\r\nUse --skip-version-check commandline argument to disable this check.\r\n            \"\"\".strip())\r\n\r\n    if gradio.__version__ != expected_gradio_version:\r\n        print_error_explanation(f\"\"\"\r\nYou are running gradio {gradio.__version__}.\r\nThe program is designed to work with gradio {expected_gradio_version}.\r\nUsing a different version of gradio is extremely likely to break the program.\r\n\r\nReasons why you have the mismatched gradio version can be:\r\n  - you use --skip-install flag.\r\n  - you use webui.py to start the program instead of launch.py.\r\n  - an extension installs the incompatible gradio version.\r\n\r\nUse --skip-version-check commandline argument to disable this check.\r\n        \"\"\".strip())\r\n\r\n"
  },
  {
    "path": "modules/esrgan_model.py",
    "content": "from modules import modelloader, devices, errors\r\nfrom modules.shared import opts\r\nfrom modules.upscaler import Upscaler, UpscalerData\r\nfrom modules.upscaler_utils import upscale_with_model\r\n\r\n\r\nclass UpscalerESRGAN(Upscaler):\r\n    def __init__(self, dirname):\r\n        self.name = \"ESRGAN\"\r\n        self.model_url = \"https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth\"\r\n        self.model_name = \"ESRGAN_4x\"\r\n        self.scalers = []\r\n        self.user_path = dirname\r\n        super().__init__()\r\n        model_paths = self.find_models(ext_filter=[\".pt\", \".pth\"])\r\n        scalers = []\r\n        if len(model_paths) == 0:\r\n            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)\r\n            scalers.append(scaler_data)\r\n        for file in model_paths:\r\n            if file.startswith(\"http\"):\r\n                name = self.model_name\r\n            else:\r\n                name = modelloader.friendly_name(file)\r\n\r\n            scaler_data = UpscalerData(name, file, self, 4)\r\n            self.scalers.append(scaler_data)\r\n\r\n    def do_upscale(self, img, selected_model):\r\n        try:\r\n            model = self.load_model(selected_model)\r\n        except Exception:\r\n            errors.report(f\"Unable to load ESRGAN model {selected_model}\", exc_info=True)\r\n            return img\r\n        model.to(devices.device_esrgan)\r\n        return esrgan_upscale(model, img)\r\n\r\n    def load_model(self, path: str):\r\n        if path.startswith(\"http\"):\r\n            # TODO: this doesn't use `path` at all?\r\n            filename = modelloader.load_file_from_url(\r\n                url=self.model_url,\r\n                model_dir=self.model_download_path,\r\n                file_name=f\"{self.model_name}.pth\",\r\n            )\r\n        else:\r\n            filename = path\r\n\r\n        return modelloader.load_spandrel_model(\r\n            filename,\r\n   
         device=('cpu' if devices.device_esrgan.type == 'mps' else None),\r\n            expected_architecture='ESRGAN',\r\n        )\r\n\r\n\r\ndef esrgan_upscale(model, img):\r\n    return upscale_with_model(\r\n        model,\r\n        img,\r\n        tile_size=opts.ESRGAN_tile,\r\n        tile_overlap=opts.ESRGAN_tile_overlap,\r\n    )\r\n"
  },
  {
    "path": "modules/extensions.py",
    "content": "from __future__ import annotations\r\n\r\nimport configparser\r\nimport dataclasses\r\nimport os\r\nimport threading\r\nimport re\r\n\r\nfrom modules import shared, errors, cache, scripts\r\nfrom modules.gitpython_hack import Repo\r\nfrom modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401\r\n\r\nextensions: list[Extension] = []\r\nextension_paths: dict[str, Extension] = {}\r\nloaded_extensions: dict[str, Exception] = {}\r\n\r\n\r\nos.makedirs(extensions_dir, exist_ok=True)\r\n\r\n\r\ndef active():\r\n    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == \"all\":\r\n        return []\r\n    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == \"extra\":\r\n        return [x for x in extensions if x.enabled and x.is_builtin]\r\n    else:\r\n        return [x for x in extensions if x.enabled]\r\n\r\n\r\n@dataclasses.dataclass\r\nclass CallbackOrderInfo:\r\n    name: str\r\n    before: list\r\n    after: list\r\n\r\n\r\nclass ExtensionMetadata:\r\n    filename = \"metadata.ini\"\r\n    config: configparser.ConfigParser\r\n    canonical_name: str\r\n    requires: list\r\n\r\n    def __init__(self, path, canonical_name):\r\n        self.config = configparser.ConfigParser()\r\n\r\n        filepath = os.path.join(path, self.filename)\r\n        # `self.config.read()` will quietly swallow OSErrors (which FileNotFoundError is),\r\n        # so no need to check whether the file exists beforehand.\r\n        try:\r\n            self.config.read(filepath)\r\n        except Exception:\r\n            errors.report(f\"Error reading {self.filename} for extension {canonical_name}.\", exc_info=True)\r\n\r\n        self.canonical_name = self.config.get(\"Extension\", \"Name\", fallback=canonical_name)\r\n        self.canonical_name = canonical_name.lower().strip()\r\n\r\n        self.requires = None\r\n\r\n    def get_script_requirements(self, field, section, 
extra_section=None):\r\n        \"\"\"reads a list of requirements from the config; field is the name of the field in the ini file,\r\n        like Requires or Before, and section is the name of the [section] in the ini file; additionally,\r\n        reads more requirements from [extra_section] if specified.\"\"\"\r\n\r\n        x = self.config.get(section, field, fallback='')\r\n\r\n        if extra_section:\r\n            x = x + ', ' + self.config.get(extra_section, field, fallback='')\r\n\r\n        listed_requirements = self.parse_list(x.lower())\r\n        res = []\r\n\r\n        for requirement in listed_requirements:\r\n            loaded_requirements = (x for x in requirement.split(\"|\") if x in loaded_extensions)\r\n            relevant_requirement = next(loaded_requirements, requirement)\r\n            res.append(relevant_requirement)\r\n\r\n        return res\r\n\r\n    def parse_list(self, text):\r\n        \"\"\"converts a line from config (\"ext1 ext2, ext3  \") into a python list ([\"ext1\", \"ext2\", \"ext3\"])\"\"\"\r\n\r\n        if not text:\r\n            return []\r\n\r\n        # both \",\" and \" \" are accepted as separator\r\n        return [x for x in re.split(r\"[,\\s]+\", text.strip()) if x]\r\n\r\n    def list_callback_order_instructions(self):\r\n        for section in self.config.sections():\r\n            if not section.startswith(\"callbacks/\"):\r\n                continue\r\n\r\n            callback_name = section[10:]\r\n\r\n            if not callback_name.startswith(self.canonical_name):\r\n                errors.report(f\"Callback order section for extension {self.canonical_name} is referencing the wrong extension: {section}\")\r\n                continue\r\n\r\n            before = self.parse_list(self.config.get(section, 'Before', fallback=''))\r\n            after = self.parse_list(self.config.get(section, 'After', fallback=''))\r\n\r\n            yield CallbackOrderInfo(callback_name, before, after)\r\n\r\n\r\nclass 
Extension:\r\n    lock = threading.Lock()\r\n    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']\r\n    metadata: ExtensionMetadata\r\n\r\n    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):\r\n        self.name = name\r\n        self.path = path\r\n        self.enabled = enabled\r\n        self.status = ''\r\n        self.can_update = False\r\n        self.is_builtin = is_builtin\r\n        self.commit_hash = ''\r\n        self.commit_date = None\r\n        self.version = ''\r\n        self.branch = None\r\n        self.remote = None\r\n        self.have_info_from_repo = False\r\n        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())\r\n        self.canonical_name = metadata.canonical_name\r\n\r\n    def to_dict(self):\r\n        return {x: getattr(self, x) for x in self.cached_fields}\r\n\r\n    def from_dict(self, d):\r\n        for field in self.cached_fields:\r\n            setattr(self, field, d[field])\r\n\r\n    def read_info_from_repo(self):\r\n        if self.is_builtin or self.have_info_from_repo:\r\n            return\r\n\r\n        def read_from_repo():\r\n            with self.lock:\r\n                if self.have_info_from_repo:\r\n                    return\r\n\r\n                self.do_read_info_from_repo()\r\n\r\n                return self.to_dict()\r\n\r\n        try:\r\n            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, \".git\"), read_from_repo)\r\n            self.from_dict(d)\r\n        except FileNotFoundError:\r\n            pass\r\n        self.status = 'unknown' if self.status == '' else self.status\r\n\r\n    def do_read_info_from_repo(self):\r\n        repo = None\r\n        try:\r\n            if os.path.exists(os.path.join(self.path, \".git\")):\r\n                repo = Repo(self.path)\r\n        except Exception:\r\n            errors.report(f\"Error reading github repository info from 
{self.path}\", exc_info=True)\r\n\r\n        if repo is None or repo.bare:\r\n            self.remote = None\r\n        else:\r\n            try:\r\n                self.remote = next(repo.remote().urls, None)\r\n                commit = repo.head.commit\r\n                self.commit_date = commit.committed_date\r\n                if repo.active_branch:\r\n                    self.branch = repo.active_branch.name\r\n                self.commit_hash = commit.hexsha\r\n                self.version = self.commit_hash[:8]\r\n\r\n            except Exception:\r\n                errors.report(f\"Failed reading extension data from Git repository ({self.name})\", exc_info=True)\r\n                self.remote = None\r\n\r\n        self.have_info_from_repo = True\r\n\r\n    def list_files(self, subdir, extension):\r\n        dirpath = os.path.join(self.path, subdir)\r\n        if not os.path.isdir(dirpath):\r\n            return []\r\n\r\n        res = []\r\n        for filename in sorted(os.listdir(dirpath)):\r\n            res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))\r\n\r\n        res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]\r\n\r\n        return res\r\n\r\n    def check_updates(self):\r\n        repo = Repo(self.path)\r\n        branch_name = f'{repo.remote().name}/{self.branch}'\r\n        for fetch in repo.remote().fetch(dry_run=True):\r\n            if self.branch and fetch.name != branch_name:\r\n                continue\r\n            if fetch.flags != fetch.HEAD_UPTODATE:\r\n                self.can_update = True\r\n                self.status = \"new commits\"\r\n                return\r\n\r\n        try:\r\n            origin = repo.rev_parse(branch_name)\r\n            if repo.head.commit != origin:\r\n                self.can_update = True\r\n                self.status = \"behind HEAD\"\r\n                return\r\n        except Exception:\r\n            
self.can_update = False\r\n            self.status = \"unknown (remote error)\"\r\n            return\r\n\r\n        self.can_update = False\r\n        self.status = \"latest\"\r\n\r\n    def fetch_and_reset_hard(self, commit=None):\r\n        repo = Repo(self.path)\r\n        if commit is None:\r\n            commit = f'{repo.remote().name}/{self.branch}'\r\n        # Fix: `error: Your local changes to the following files would be overwritten by merge`,\r\n        # because WSL2 Docker set 755 file permissions instead of 644, this results to the error.\r\n        repo.git.fetch(all=True)\r\n        repo.git.reset(commit, hard=True)\r\n        self.have_info_from_repo = False\r\n\r\n\r\ndef list_extensions():\r\n    extensions.clear()\r\n    extension_paths.clear()\r\n    loaded_extensions.clear()\r\n\r\n    if shared.cmd_opts.disable_all_extensions:\r\n        print(\"*** \\\"--disable-all-extensions\\\" arg was used, will not load any extensions ***\")\r\n    elif shared.opts.disable_all_extensions == \"all\":\r\n        print(\"*** \\\"Disable all extensions\\\" option was set, will not load any extensions ***\")\r\n    elif shared.cmd_opts.disable_extra_extensions:\r\n        print(\"*** \\\"--disable-extra-extensions\\\" arg was used, will only load built-in extensions ***\")\r\n    elif shared.opts.disable_all_extensions == \"extra\":\r\n        print(\"*** \\\"Disable all extensions\\\" option was set, will only load built-in extensions ***\")\r\n\r\n\r\n    # scan through extensions directory and load metadata\r\n    for dirname in [extensions_builtin_dir, extensions_dir]:\r\n        if not os.path.isdir(dirname):\r\n            continue\r\n\r\n        for extension_dirname in sorted(os.listdir(dirname)):\r\n            path = os.path.join(dirname, extension_dirname)\r\n            if not os.path.isdir(path):\r\n                continue\r\n\r\n            canonical_name = extension_dirname\r\n            metadata = ExtensionMetadata(path, 
canonical_name)\r\n\r\n            # check for duplicated canonical names\r\n            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)\r\n            if already_loaded_extension is not None:\r\n                errors.report(f'Duplicate canonical name \"{canonical_name}\" found in extensions \"{extension_dirname}\" and \"{already_loaded_extension.name}\". Former will be discarded.', exc_info=False)\r\n                continue\r\n\r\n            is_builtin = dirname == extensions_builtin_dir\r\n            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)\r\n            extensions.append(extension)\r\n            extension_paths[extension.path] = extension\r\n            loaded_extensions[canonical_name] = extension\r\n\r\n    for extension in extensions:\r\n        extension.metadata.requires = extension.metadata.get_script_requirements(\"Requires\", \"Extension\")\r\n\r\n    # check for requirements\r\n    for extension in extensions:\r\n        if not extension.enabled:\r\n            continue\r\n\r\n        for req in extension.metadata.requires:\r\n            required_extension = loaded_extensions.get(req)\r\n            if required_extension is None:\r\n                errors.report(f'Extension \"{extension.name}\" requires \"{req}\" which is not installed.', exc_info=False)\r\n                continue\r\n\r\n            if not required_extension.enabled:\r\n                errors.report(f'Extension \"{extension.name}\" requires \"{required_extension.name}\" which is disabled.', exc_info=False)\r\n                continue\r\n\r\n\r\ndef find_extension(filename):\r\n    parentdir = os.path.dirname(os.path.realpath(filename))\r\n\r\n    while parentdir != filename:\r\n        extension = extension_paths.get(parentdir)\r\n        if extension is not None:\r\n            return extension\r\n\r\n        filename = 
parentdir\r\n        parentdir = os.path.dirname(filename)\r\n\r\n    return None\r\n\r\n"
  },
  {
    "path": "modules/extra_networks.py",
    "content": "import json\r\nimport os\r\nimport re\r\nimport logging\r\nfrom collections import defaultdict\r\n\r\nfrom modules import errors\r\n\r\nextra_network_registry = {}\r\nextra_network_aliases = {}\r\n\r\n\r\ndef initialize():\r\n    extra_network_registry.clear()\r\n    extra_network_aliases.clear()\r\n\r\n\r\ndef register_extra_network(extra_network):\r\n    extra_network_registry[extra_network.name] = extra_network\r\n\r\n\r\ndef register_extra_network_alias(extra_network, alias):\r\n    extra_network_aliases[alias] = extra_network\r\n\r\n\r\ndef register_default_extra_networks():\r\n    from modules.extra_networks_hypernet import ExtraNetworkHypernet\r\n    register_extra_network(ExtraNetworkHypernet())\r\n\r\n\r\nclass ExtraNetworkParams:\r\n    def __init__(self, items=None):\r\n        self.items = items or []\r\n        self.positional = []\r\n        self.named = {}\r\n\r\n        for item in self.items:\r\n            parts = item.split('=', 2) if isinstance(item, str) else [item]\r\n            if len(parts) == 2:\r\n                self.named[parts[0]] = parts[1]\r\n            else:\r\n                self.positional.append(item)\r\n\r\n    def __eq__(self, other):\r\n        return self.items == other.items\r\n\r\n\r\nclass ExtraNetwork:\r\n    def __init__(self, name):\r\n        self.name = name\r\n\r\n    def activate(self, p, params_list):\r\n        \"\"\"\r\n        Called by processing on every run. 
Whatever the extra network is meant to do should be activated here.\r\n        Passes arguments related to this extra network in params_list.\r\n        User passes arguments by specifying this in his prompt:\r\n\r\n        <name:arg1:arg2:arg3>\r\n\r\n        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments\r\n        separated by colon.\r\n\r\n        Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -\r\n        in this case, all effects of this extra networks should be disabled.\r\n\r\n        Can be called multiple times before deactivate() - each new call should override the previous call completely.\r\n\r\n        For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:\r\n\r\n        > \"1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>\"\r\n\r\n        params_list will be:\r\n\r\n        [\r\n            ExtraNetworkParams(items=[\"agm\", \"1.1\"]),\r\n            ExtraNetworkParams(items=[\"ray\"])\r\n        ]\r\n\r\n        \"\"\"\r\n        raise NotImplementedError\r\n\r\n    def deactivate(self, p):\r\n        \"\"\"\r\n        Called at the end of processing for housekeeping. 
No need to do anything here.\r\n        \"\"\"\r\n\r\n        raise NotImplementedError\r\n\r\n\r\ndef lookup_extra_networks(extra_network_data):\r\n    \"\"\"returns a dict mapping ExtraNetwork objects to lists of arguments for those extra networks.\r\n\r\n    Example input:\r\n    {\r\n        'lora': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>],\r\n        'lyco': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],\r\n        'hypernet': [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]\r\n    }\r\n\r\n    Example output:\r\n\r\n    {\r\n        <extra_networks_lora.ExtraNetworkLora object at 0x0000020581BEECE0>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58310>, <modules.extra_networks.ExtraNetworkParams object at 0x0000020690D58F70>],\r\n        <modules.extra_networks_hypernet.ExtraNetworkHypernet object at 0x0000020581BEEE60>: [<modules.extra_networks.ExtraNetworkParams object at 0x0000020690D5A800>]\r\n    }\r\n    \"\"\"\r\n\r\n    res = {}\r\n\r\n    for extra_network_name, extra_network_args in list(extra_network_data.items()):\r\n        extra_network = extra_network_registry.get(extra_network_name, None)\r\n        alias = extra_network_aliases.get(extra_network_name, None)\r\n\r\n        if alias is not None and extra_network is None:\r\n            extra_network = alias\r\n\r\n        if extra_network is None:\r\n            logging.info(f\"Skipping unknown extra network: {extra_network_name}\")\r\n            continue\r\n\r\n        res.setdefault(extra_network, []).extend(extra_network_args)\r\n\r\n    return res\r\n\r\n\r\ndef activate(p, extra_network_data):\r\n    \"\"\"call activate for extra networks in extra_network_data in specified order, then call\r\n    activate for all remaining registered networks with an empty argument list\"\"\"\r\n\r\n    activated = []\r\n\r\n    for extra_network, extra_network_args in 
lookup_extra_networks(extra_network_data).items():\r\n\r\n        try:\r\n            extra_network.activate(p, extra_network_args)\r\n            activated.append(extra_network)\r\n        except Exception as e:\r\n            errors.display(e, f\"activating extra network {extra_network.name} with arguments {extra_network_args}\")\r\n\r\n    for extra_network_name, extra_network in extra_network_registry.items():\r\n        if extra_network in activated:\r\n            continue\r\n\r\n        try:\r\n            extra_network.activate(p, [])\r\n        except Exception as e:\r\n            errors.display(e, f\"activating extra network {extra_network_name}\")\r\n\r\n    if p.scripts is not None:\r\n        p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data)\r\n\r\n\r\ndef deactivate(p, extra_network_data):\r\n    \"\"\"call deactivate for extra networks in extra_network_data in specified order, then call\r\n    deactivate for all remaining registered networks\"\"\"\r\n\r\n    data = lookup_extra_networks(extra_network_data)\r\n\r\n    for extra_network in data:\r\n        try:\r\n            extra_network.deactivate(p)\r\n        except Exception as e:\r\n            errors.display(e, f\"deactivating extra network {extra_network.name}\")\r\n\r\n    for extra_network_name, extra_network in extra_network_registry.items():\r\n        if extra_network in data:\r\n            continue\r\n\r\n        try:\r\n            extra_network.deactivate(p)\r\n        except Exception as e:\r\n            errors.display(e, f\"deactivating unmentioned extra network {extra_network_name}\")\r\n\r\n\r\nre_extra_net = re.compile(r\"<(\\w+):([^>]+)>\")\r\n\r\n\r\ndef parse_prompt(prompt):\r\n    res = defaultdict(list)\r\n\r\n    def found(m):\r\n        name = m.group(1)\r\n        args = m.group(2)\r\n\r\n        
res[name].append(ExtraNetworkParams(items=args.split(\":\")))\r\n\r\n        return \"\"\r\n\r\n    prompt = re.sub(re_extra_net, found, prompt)\r\n\r\n    return prompt, res\r\n\r\n\r\ndef parse_prompts(prompts):\r\n    res = []\r\n    extra_data = None\r\n\r\n    for prompt in prompts:\r\n        updated_prompt, parsed_extra_data = parse_prompt(prompt)\r\n\r\n        if extra_data is None:\r\n            extra_data = parsed_extra_data\r\n\r\n        res.append(updated_prompt)\r\n\r\n    return res, extra_data\r\n\r\n\r\ndef get_user_metadata(filename, lister=None):\r\n    if filename is None:\r\n        return {}\r\n\r\n    basename, ext = os.path.splitext(filename)\r\n    metadata_filename = basename + '.json'\r\n\r\n    metadata = {}\r\n    try:\r\n        exists = lister.exists(metadata_filename) if lister else os.path.exists(metadata_filename)\r\n        if exists:\r\n            with open(metadata_filename, \"r\", encoding=\"utf8\") as file:\r\n                metadata = json.load(file)\r\n    except Exception as e:\r\n        errors.display(e, f\"reading extra network user metadata from {metadata_filename}\")\r\n\r\n    return metadata\r\n"
  },
  {
    "path": "modules/extra_networks_hypernet.py",
    "content": "from modules import extra_networks, shared\r\nfrom modules.hypernetworks import hypernetwork\r\n\r\n\r\nclass ExtraNetworkHypernet(extra_networks.ExtraNetwork):\r\n    def __init__(self):\r\n        super().__init__('hypernet')\r\n\r\n    def activate(self, p, params_list):\r\n        additional = shared.opts.sd_hypernetwork\r\n\r\n        if additional != \"None\" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional):\r\n            hypernet_prompt_text = f\"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>\"\r\n            p.all_prompts = [f\"{prompt}{hypernet_prompt_text}\" for prompt in p.all_prompts]\r\n            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))\r\n\r\n        names = []\r\n        multipliers = []\r\n        for params in params_list:\r\n            assert params.items\r\n\r\n            names.append(params.items[0])\r\n            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)\r\n\r\n        hypernetwork.load_hypernetworks(names, multipliers)\r\n\r\n    def deactivate(self, p):\r\n        pass\r\n"
  },
  {
    "path": "modules/extras.py",
    "content": "import os\r\nimport re\r\nimport shutil\r\nimport json\r\n\r\n\r\nimport torch\r\nimport tqdm\r\n\r\nfrom modules import shared, images, sd_models, sd_vae, sd_models_config, errors\r\nfrom modules.ui_common import plaintext_to_html\r\nimport gradio as gr\r\nimport safetensors.torch\r\n\r\n\r\ndef run_pnginfo(image):\r\n    if image is None:\r\n        return '', '', ''\r\n\r\n    geninfo, items = images.read_info_from_image(image)\r\n    items = {**{'parameters': geninfo}, **items}\r\n\r\n    info = ''\r\n    for key, text in items.items():\r\n        info += f\"\"\"\r\n<div>\r\n<p><b>{plaintext_to_html(str(key))}</b></p>\r\n<p>{plaintext_to_html(str(text))}</p>\r\n</div>\r\n\"\"\".strip()+\"\\n\"\r\n\r\n    if len(info) == 0:\r\n        message = \"Nothing found in the image.\"\r\n        info = f\"<div><p>{message}<p></div>\"\r\n\r\n    return '', geninfo, info\r\n\r\n\r\ndef create_config(ckpt_result, config_source, a, b, c):\r\n    def config(x):\r\n        res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None\r\n        return res if res != shared.sd_default_config else None\r\n\r\n    if config_source == 0:\r\n        cfg = config(a) or config(b) or config(c)\r\n    elif config_source == 1:\r\n        cfg = config(b)\r\n    elif config_source == 2:\r\n        cfg = config(c)\r\n    else:\r\n        cfg = None\r\n\r\n    if cfg is None:\r\n        return\r\n\r\n    filename, _ = os.path.splitext(ckpt_result)\r\n    checkpoint_filename = filename + \".yaml\"\r\n\r\n    print(\"Copying config:\")\r\n    print(\"   from:\", cfg)\r\n    print(\"     to:\", checkpoint_filename)\r\n    shutil.copyfile(cfg, checkpoint_filename)\r\n\r\n\r\ncheckpoint_dict_skip_on_merge = [\"cond_stage_model.transformer.text_model.embeddings.position_ids\"]\r\n\r\n\r\ndef to_half(tensor, enable):\r\n    if enable and tensor.dtype == torch.float:\r\n        return tensor.half()\r\n\r\n    return tensor\r\n\r\n\r\ndef 
read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):\r\n    metadata = {}\r\n\r\n    for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:\r\n        checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)\r\n        if checkpoint_info is None:\r\n            continue\r\n\r\n        metadata.update(checkpoint_info.metadata)\r\n\r\n    return json.dumps(metadata, indent=4, ensure_ascii=False)\r\n\r\n\r\ndef run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):\r\n    shared.state.begin(job=\"model-merge\")\r\n\r\n    def fail(message):\r\n        shared.state.textinfo = message\r\n        shared.state.end()\r\n        return [*[gr.update() for _ in range(4)], message]\r\n\r\n    def weighted_sum(theta0, theta1, alpha):\r\n        return ((1 - alpha) * theta0) + (alpha * theta1)\r\n\r\n    def get_difference(theta1, theta2):\r\n        return theta1 - theta2\r\n\r\n    def add_difference(theta0, theta1_2_diff, alpha):\r\n        return theta0 + (alpha * theta1_2_diff)\r\n\r\n    def filename_weighted_sum():\r\n        a = primary_model_info.model_name\r\n        b = secondary_model_info.model_name\r\n        Ma = round(1 - multiplier, 2)\r\n        Mb = round(multiplier, 2)\r\n\r\n        return f\"{Ma}({a}) + {Mb}({b})\"\r\n\r\n    def filename_add_difference():\r\n        a = primary_model_info.model_name\r\n        b = secondary_model_info.model_name\r\n        c = tertiary_model_info.model_name\r\n        M = round(multiplier, 2)\r\n\r\n        return f\"{a} + {M}({b} - {c})\"\r\n\r\n    def filename_nothing():\r\n        return primary_model_info.model_name\r\n\r\n    theta_funcs = {\r\n        \"Weighted sum\": (filename_weighted_sum, None, weighted_sum),\r\n        
\"Add difference\": (filename_add_difference, get_difference, add_difference),\r\n        \"No interpolation\": (filename_nothing, None, None),\r\n    }\r\n    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]\r\n    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)\r\n\r\n    if not primary_model_name:\r\n        return fail(\"Failed: Merging requires a primary model.\")\r\n\r\n    primary_model_info = sd_models.checkpoints_list[primary_model_name]\r\n\r\n    if theta_func2 and not secondary_model_name:\r\n        return fail(\"Failed: Merging requires a secondary model.\")\r\n\r\n    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None\r\n\r\n    if theta_func1 and not tertiary_model_name:\r\n        return fail(f\"Failed: Interpolation method ({interp_method}) requires a tertiary model.\")\r\n\r\n    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None\r\n\r\n    result_is_inpainting_model = False\r\n    result_is_instruct_pix2pix_model = False\r\n\r\n    if theta_func2:\r\n        shared.state.textinfo = \"Loading B\"\r\n        print(f\"Loading {secondary_model_info.filename}...\")\r\n        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')\r\n    else:\r\n        theta_1 = None\r\n\r\n    if theta_func1:\r\n        shared.state.textinfo = \"Loading C\"\r\n        print(f\"Loading {tertiary_model_info.filename}...\")\r\n        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')\r\n\r\n        shared.state.textinfo = 'Merging B and C'\r\n        shared.state.sampling_steps = len(theta_1.keys())\r\n        for key in tqdm.tqdm(theta_1.keys()):\r\n            if key in checkpoint_dict_skip_on_merge:\r\n                continue\r\n\r\n            if 'model' in key:\r\n                if key in theta_2:\r\n                    t2 = theta_2.get(key, 
torch.zeros_like(theta_1[key]))\r\n                    theta_1[key] = theta_func1(theta_1[key], t2)\r\n                else:\r\n                    theta_1[key] = torch.zeros_like(theta_1[key])\r\n\r\n            shared.state.sampling_step += 1\r\n        del theta_2\r\n\r\n        shared.state.nextjob()\r\n\r\n    shared.state.textinfo = f\"Loading {primary_model_info.filename}...\"\r\n    print(f\"Loading {primary_model_info.filename}...\")\r\n    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')\r\n\r\n    print(\"Merging...\")\r\n    shared.state.textinfo = 'Merging A and B'\r\n    shared.state.sampling_steps = len(theta_0.keys())\r\n    for key in tqdm.tqdm(theta_0.keys()):\r\n        if theta_1 and 'model' in key and key in theta_1:\r\n\r\n            if key in checkpoint_dict_skip_on_merge:\r\n                continue\r\n\r\n            a = theta_0[key]\r\n            b = theta_1[key]\r\n\r\n            # this enables merging an inpainting model (A) with another one (B);\r\n            # where normal model would have 4 channels, for latenst space, inpainting model would\r\n            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9\r\n            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:\r\n                if a.shape[1] == 4 and b.shape[1] == 9:\r\n                    raise RuntimeError(\"When merging inpainting model with a normal one, A must be the inpainting model.\")\r\n                if a.shape[1] == 4 and b.shape[1] == 8:\r\n                    raise RuntimeError(\"When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.\")\r\n\r\n                if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model...\r\n                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common.  
Otherwise we get an error due to dimension mismatch.\r\n                    result_is_instruct_pix2pix_model = True\r\n                else:\r\n                    assert a.shape[1] == 9 and b.shape[1] == 4, f\"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}\"\r\n                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)\r\n                    result_is_inpainting_model = True\r\n            else:\r\n                theta_0[key] = theta_func2(a, b, multiplier)\r\n\r\n            theta_0[key] = to_half(theta_0[key], save_as_half)\r\n\r\n        shared.state.sampling_step += 1\r\n\r\n    del theta_1\r\n\r\n    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)\r\n    if bake_in_vae_filename is not None:\r\n        print(f\"Baking in VAE from {bake_in_vae_filename}\")\r\n        shared.state.textinfo = 'Baking in VAE'\r\n        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')\r\n\r\n        for key in vae_dict.keys():\r\n            theta_0_key = 'first_stage_model.' 
+ key\r\n            if theta_0_key in theta_0:\r\n                theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)\r\n\r\n        del vae_dict\r\n\r\n    if save_as_half and not theta_func2:\r\n        for key in theta_0.keys():\r\n            theta_0[key] = to_half(theta_0[key], save_as_half)\r\n\r\n    if discard_weights:\r\n        regex = re.compile(discard_weights)\r\n        for key in list(theta_0):\r\n            if re.search(regex, key):\r\n                theta_0.pop(key, None)\r\n\r\n    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path\r\n\r\n    filename = filename_generator() if custom_name == '' else custom_name\r\n    filename += \".inpainting\" if result_is_inpainting_model else \"\"\r\n    filename += \".instruct-pix2pix\" if result_is_instruct_pix2pix_model else \"\"\r\n    filename += \".\" + checkpoint_format\r\n\r\n    output_modelname = os.path.join(ckpt_dir, filename)\r\n\r\n    shared.state.nextjob()\r\n    shared.state.textinfo = \"Saving\"\r\n    print(f\"Saving to {output_modelname}...\")\r\n\r\n    metadata = {}\r\n\r\n    if save_metadata and copy_metadata_fields:\r\n        if primary_model_info:\r\n            metadata.update(primary_model_info.metadata)\r\n        if secondary_model_info:\r\n            metadata.update(secondary_model_info.metadata)\r\n        if tertiary_model_info:\r\n            metadata.update(tertiary_model_info.metadata)\r\n\r\n    if save_metadata:\r\n        try:\r\n            metadata.update(json.loads(metadata_json))\r\n        except Exception as e:\r\n            errors.display(e, \"readin metadata from json\")\r\n\r\n        metadata[\"format\"] = \"pt\"\r\n\r\n    if save_metadata and add_merge_recipe:\r\n        merge_recipe = {\r\n            \"type\": \"webui\", # indicate this model was merged with webui's built-in merger\r\n            \"primary_model_hash\": primary_model_info.sha256,\r\n            \"secondary_model_hash\": secondary_model_info.sha256 if 
secondary_model_info else None,\r\n            \"tertiary_model_hash\": tertiary_model_info.sha256 if tertiary_model_info else None,\r\n            \"interp_method\": interp_method,\r\n            \"multiplier\": multiplier,\r\n            \"save_as_half\": save_as_half,\r\n            \"custom_name\": custom_name,\r\n            \"config_source\": config_source,\r\n            \"bake_in_vae\": bake_in_vae,\r\n            \"discard_weights\": discard_weights,\r\n            \"is_inpainting\": result_is_inpainting_model,\r\n            \"is_instruct_pix2pix\": result_is_instruct_pix2pix_model\r\n        }\r\n\r\n        sd_merge_models = {}\r\n\r\n        def add_model_metadata(checkpoint_info):\r\n            checkpoint_info.calculate_shorthash()\r\n            sd_merge_models[checkpoint_info.sha256] = {\r\n                \"name\": checkpoint_info.name,\r\n                \"legacy_hash\": checkpoint_info.hash,\r\n                \"sd_merge_recipe\": checkpoint_info.metadata.get(\"sd_merge_recipe\", None)\r\n            }\r\n\r\n            sd_merge_models.update(checkpoint_info.metadata.get(\"sd_merge_models\", {}))\r\n\r\n        add_model_metadata(primary_model_info)\r\n        if secondary_model_info:\r\n            add_model_metadata(secondary_model_info)\r\n        if tertiary_model_info:\r\n            add_model_metadata(tertiary_model_info)\r\n\r\n        metadata[\"sd_merge_recipe\"] = json.dumps(merge_recipe)\r\n        metadata[\"sd_merge_models\"] = json.dumps(sd_merge_models)\r\n\r\n    _, extension = os.path.splitext(output_modelname)\r\n    if extension.lower() == \".safetensors\":\r\n        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)\r\n    else:\r\n        torch.save(theta_0, output_modelname)\r\n\r\n    sd_models.list_models()\r\n    created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)\r\n    if created_model:\r\n        
created_model.calculate_shorthash()\r\n\r\n    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)\r\n\r\n    print(f\"Checkpoint saved to {output_modelname}.\")\r\n    shared.state.textinfo = \"Checkpoint saved\"\r\n    shared.state.end()\r\n\r\n    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], \"Checkpoint saved to \" + output_modelname]\r\n"
  },
  {
    "path": "modules/face_restoration.py",
    "content": "from modules import shared\r\n\r\n\r\nclass FaceRestoration:\r\n    def name(self):\r\n        return \"None\"\r\n\r\n    def restore(self, np_image):\r\n        return np_image\r\n\r\n\r\ndef restore_faces(np_image):\r\n    face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]\r\n    if len(face_restorers) == 0:\r\n        return np_image\r\n\r\n    face_restorer = face_restorers[0]\r\n\r\n    return face_restorer.restore(np_image)\r\n"
  },
  {
    "path": "modules/face_restoration_utils.py",
    "content": "from __future__ import annotations\n\nimport logging\nimport os\nfrom functools import cached_property\nfrom typing import TYPE_CHECKING, Callable\n\nimport cv2\nimport numpy as np\nimport torch\n\nfrom modules import devices, errors, face_restoration, shared\n\nif TYPE_CHECKING:\n    from facexlib.utils.face_restoration_helper import FaceRestoreHelper\n\nlogger = logging.getLogger(__name__)\n\n\ndef bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor:\n    \"\"\"Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.\"\"\"\n    assert img.shape[2] == 3, \"image must be RGB\"\n    if img.dtype == \"float64\":\n        img = img.astype(\"float32\")\n    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n    return torch.from_numpy(img.transpose(2, 0, 1)).float()\n\n\ndef rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray:\n    \"\"\"\n    Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range.\n    \"\"\"\n    tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)\n    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])\n    assert tensor.dim() == 3, \"tensor must be RGB\"\n    img_np = tensor.numpy().transpose(1, 2, 0)\n    if img_np.shape[2] == 1:  # gray image, no RGB/BGR required\n        return np.squeeze(img_np, axis=2)\n    return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)\n\n\ndef create_face_helper(device) -> FaceRestoreHelper:\n    from facexlib.detection import retinaface\n    from facexlib.utils.face_restoration_helper import FaceRestoreHelper\n    if hasattr(retinaface, 'device'):\n        retinaface.device = device\n    return FaceRestoreHelper(\n        upscale_factor=1,\n        face_size=512,\n        crop_ratio=(1, 1),\n        det_model='retinaface_resnet50',\n        save_ext='png',\n        use_parse=True,\n        device=device,\n    )\n\n\ndef restore_with_face_helper(\n    np_image: np.ndarray,\n    face_helper: FaceRestoreHelper,\n  
  restore_face: Callable[[torch.Tensor], torch.Tensor],\n) -> np.ndarray:\n    \"\"\"\n    Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.\n\n    `restore_face` should take a cropped face image and return a restored face image.\n    \"\"\"\n    from torchvision.transforms.functional import normalize\n    np_image = np_image[:, :, ::-1]\n    original_resolution = np_image.shape[0:2]\n\n    try:\n        logger.debug(\"Detecting faces...\")\n        face_helper.clean_all()\n        face_helper.read_image(np_image)\n        face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)\n        face_helper.align_warp_face()\n        logger.debug(\"Found %d faces, restoring\", len(face_helper.cropped_faces))\n        for cropped_face in face_helper.cropped_faces:\n            cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0)\n            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)\n            cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)\n\n            try:\n                with torch.no_grad():\n                    cropped_face_t = restore_face(cropped_face_t)\n                devices.torch_gc()\n            except Exception:\n                errors.report('Failed face-restoration inference', exc_info=True)\n\n            restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1))\n            restored_face = (restored_face * 255.0).astype('uint8')\n            face_helper.add_restored_face(restored_face)\n\n        logger.debug(\"Merging restored faces into image\")\n        face_helper.get_inverse_affine(None)\n        img = face_helper.paste_faces_to_input_image()\n        img = img[:, :, ::-1]\n        if original_resolution != img.shape[0:2]:\n            img = cv2.resize(\n                img,\n                (0, 0),\n                fx=original_resolution[1] / img.shape[1],\n           
     fy=original_resolution[0] / img.shape[0],\n                interpolation=cv2.INTER_LINEAR,\n            )\n        logger.debug(\"Face restoration complete\")\n    finally:\n        face_helper.clean_all()\n    return img\n\n\nclass CommonFaceRestoration(face_restoration.FaceRestoration):\n    net: torch.Module | None\n    model_url: str\n    model_download_name: str\n\n    def __init__(self, model_path: str):\n        super().__init__()\n        self.net = None\n        self.model_path = model_path\n        os.makedirs(model_path, exist_ok=True)\n\n    @cached_property\n    def face_helper(self) -> FaceRestoreHelper:\n        return create_face_helper(self.get_device())\n\n    def send_model_to(self, device):\n        if self.net:\n            logger.debug(\"Sending %s to %s\", self.net, device)\n            self.net.to(device)\n        if self.face_helper:\n            logger.debug(\"Sending face helper to %s\", device)\n            self.face_helper.face_det.to(device)\n            self.face_helper.face_parse.to(device)\n\n    def get_device(self):\n        raise NotImplementedError(\"get_device must be implemented by subclasses\")\n\n    def load_net(self) -> torch.Module:\n        raise NotImplementedError(\"load_net must be implemented by subclasses\")\n\n    def restore_with_helper(\n        self,\n        np_image: np.ndarray,\n        restore_face: Callable[[torch.Tensor], torch.Tensor],\n    ) -> np.ndarray:\n        try:\n            if self.net is None:\n                self.net = self.load_net()\n        except Exception:\n            logger.warning(\"Unable to load face-restoration model\", exc_info=True)\n            return np_image\n\n        try:\n            self.send_model_to(self.get_device())\n            return restore_with_face_helper(np_image, self.face_helper, restore_face)\n        finally:\n            if shared.opts.face_restoration_unload:\n                self.send_model_to(devices.cpu)\n\n\ndef patch_facexlib(dirname: str) -> 
None:\n    import facexlib.detection\n    import facexlib.parsing\n\n    det_facex_load_file_from_url = facexlib.detection.load_file_from_url\n    par_facex_load_file_from_url = facexlib.parsing.load_file_from_url\n\n    def update_kwargs(kwargs):\n        return dict(kwargs, save_dir=dirname, model_dir=None)\n\n    def facex_load_file_from_url(**kwargs):\n        return det_facex_load_file_from_url(**update_kwargs(kwargs))\n\n    def facex_load_file_from_url2(**kwargs):\n        return par_facex_load_file_from_url(**update_kwargs(kwargs))\n\n    facexlib.detection.load_file_from_url = facex_load_file_from_url\n    facexlib.parsing.load_file_from_url = facex_load_file_from_url2\n"
  },
  {
    "path": "modules/fifo_lock.py",
    "content": "import threading\nimport collections\n\n\n# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a\nclass FIFOLock(object):\n    def __init__(self):\n        self._lock = threading.Lock()\n        self._inner_lock = threading.Lock()\n        self._pending_threads = collections.deque()\n\n    def acquire(self, blocking=True):\n        with self._inner_lock:\n            lock_acquired = self._lock.acquire(False)\n            if lock_acquired:\n                return True\n            elif not blocking:\n                return False\n\n            release_event = threading.Event()\n            self._pending_threads.append(release_event)\n\n        release_event.wait()\n        return self._lock.acquire()\n\n    def release(self):\n        with self._inner_lock:\n            if self._pending_threads:\n                release_event = self._pending_threads.popleft()\n                release_event.set()\n\n            self._lock.release()\n\n    __enter__ = acquire\n\n    def __exit__(self, t, v, tb):\n        self.release()\n"
  },
  {
    "path": "modules/gfpgan_model.py",
    "content": "from __future__ import annotations\r\n\r\nimport logging\r\nimport os\r\n\r\nimport torch\r\n\r\nfrom modules import (\r\n    devices,\r\n    errors,\r\n    face_restoration,\r\n    face_restoration_utils,\r\n    modelloader,\r\n    shared,\r\n)\r\n\r\nlogger = logging.getLogger(__name__)\r\nmodel_url = \"https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth\"\r\nmodel_download_name = \"GFPGANv1.4.pth\"\r\ngfpgan_face_restorer: face_restoration.FaceRestoration | None = None\r\n\r\n\r\nclass FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration):\r\n    def name(self):\r\n        return \"GFPGAN\"\r\n\r\n    def get_device(self):\r\n        return devices.device_gfpgan\r\n\r\n    def load_net(self) -> torch.Module:\r\n        for model_path in modelloader.load_models(\r\n            model_path=self.model_path,\r\n            model_url=model_url,\r\n            command_path=self.model_path,\r\n            download_name=model_download_name,\r\n            ext_filter=['.pth'],\r\n        ):\r\n            if 'GFPGAN' in os.path.basename(model_path):\r\n                return modelloader.load_spandrel_model(\r\n                    model_path,\r\n                    device=self.get_device(),\r\n                    expected_architecture='GFPGAN',\r\n                ).model\r\n        raise ValueError(\"No GFPGAN model found\")\r\n\r\n    def restore(self, np_image):\r\n        def restore_face(cropped_face_t):\r\n            assert self.net is not None\r\n            return self.net(cropped_face_t, return_rgb=False)[0]\r\n\r\n        return self.restore_with_helper(np_image, restore_face)\r\n\r\n\r\ndef gfpgan_fix_faces(np_image):\r\n    if gfpgan_face_restorer:\r\n        return gfpgan_face_restorer.restore(np_image)\r\n    logger.warning(\"GFPGAN face restorer not set up\")\r\n    return np_image\r\n\r\n\r\ndef setup_model(dirname: str) -> None:\r\n    global gfpgan_face_restorer\r\n\r\n    try:\r\n        
face_restoration_utils.patch_facexlib(dirname)\r\n        gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname)\r\n        shared.face_restorers.append(gfpgan_face_restorer)\r\n    except Exception:\r\n        errors.report(\"Error setting up GFPGAN\", exc_info=True)\r\n"
  },
  {
    "path": "modules/gitpython_hack.py",
    "content": "from __future__ import annotations\n\nimport io\nimport subprocess\n\nimport git\n\n\nclass Git(git.Git):\n    \"\"\"\n    Git subclassed to never use persistent processes.\n    \"\"\"\n\n    def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs):\n        raise NotImplementedError(f\"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})\")\n\n    def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]:\n        ret = subprocess.check_output(\n            [self.GIT_PYTHON_GIT_EXECUTABLE, \"cat-file\", \"--batch-check\"],\n            input=self._prepare_ref(ref),\n            cwd=self._working_dir,\n            timeout=2,\n        )\n        return self._parse_object_header(ret)\n\n    def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]:\n        # Not really streaming, per se; this buffers the entire object in memory.\n        # Shouldn't be a problem for our use case, since we're only using this for\n        # object headers (commit objects).\n        ret = subprocess.check_output(\n            [self.GIT_PYTHON_GIT_EXECUTABLE, \"cat-file\", \"--batch\"],\n            input=self._prepare_ref(ref),\n            cwd=self._working_dir,\n            timeout=30,\n        )\n        bio = io.BytesIO(ret)\n        hexsha, typename, size = self._parse_object_header(bio.readline())\n        return (hexsha, typename, size, self.CatFileContentStream(size, bio))\n\n\nclass Repo(git.Repo):\n    GitCommandWrapperType = Git\n"
  },
  {
    "path": "modules/gradio_extensons.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import scripts, ui_tempdir, patches\r\n\r\n\r\ndef add_classes_to_gradio_component(comp):\r\n    \"\"\"\r\n    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others\r\n    \"\"\"\r\n\r\n    comp.elem_classes = [f\"gradio-{comp.get_block_name()}\", *(comp.elem_classes or [])]\r\n\r\n    if getattr(comp, 'multiselect', False):\r\n        comp.elem_classes.append('multiselect')\r\n\r\n\r\ndef IOComponent_init(self, *args, **kwargs):\r\n    self.webui_tooltip = kwargs.pop('tooltip', None)\r\n\r\n    if scripts.scripts_current is not None:\r\n        scripts.scripts_current.before_component(self, **kwargs)\r\n\r\n    scripts.script_callbacks.before_component_callback(self, **kwargs)\r\n\r\n    res = original_IOComponent_init(self, *args, **kwargs)\r\n\r\n    add_classes_to_gradio_component(self)\r\n\r\n    scripts.script_callbacks.after_component_callback(self, **kwargs)\r\n\r\n    if scripts.scripts_current is not None:\r\n        scripts.scripts_current.after_component(self, **kwargs)\r\n\r\n    return res\r\n\r\n\r\ndef Block_get_config(self):\r\n    config = original_Block_get_config(self)\r\n\r\n    webui_tooltip = getattr(self, 'webui_tooltip', None)\r\n    if webui_tooltip:\r\n        config[\"webui_tooltip\"] = webui_tooltip\r\n\r\n    config.pop('example_inputs', None)\r\n\r\n    return config\r\n\r\n\r\ndef BlockContext_init(self, *args, **kwargs):\r\n    if scripts.scripts_current is not None:\r\n        scripts.scripts_current.before_component(self, **kwargs)\r\n\r\n    scripts.script_callbacks.before_component_callback(self, **kwargs)\r\n\r\n    res = original_BlockContext_init(self, *args, **kwargs)\r\n\r\n    add_classes_to_gradio_component(self)\r\n\r\n    scripts.script_callbacks.after_component_callback(self, **kwargs)\r\n\r\n    if scripts.scripts_current is not None:\r\n        scripts.scripts_current.after_component(self, **kwargs)\r\n\r\n   
 return res\r\n\r\n\r\ndef Blocks_get_config_file(self, *args, **kwargs):\r\n    config = original_Blocks_get_config_file(self, *args, **kwargs)\r\n\r\n    for comp_config in config[\"components\"]:\r\n        if \"example_inputs\" in comp_config:\r\n            comp_config[\"example_inputs\"] = {\"serialized\": []}\r\n\r\n    return config\r\n\r\n\r\noriginal_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field=\"__init__\", replacement=IOComponent_init)\r\noriginal_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field=\"get_config\", replacement=Block_get_config)\r\noriginal_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field=\"__init__\", replacement=BlockContext_init)\r\noriginal_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field=\"get_config_file\", replacement=Blocks_get_config_file)\r\n\r\n\r\nui_tempdir.install_ui_tempdir_override()\r\n"
  },
  {
    "path": "modules/hashes.py",
    "content": "import hashlib\r\nimport os.path\r\n\r\nfrom modules import shared\r\nimport modules.cache\r\n\r\ndump_cache = modules.cache.dump_cache\r\ncache = modules.cache.cache\r\n\r\n\r\ndef calculate_sha256(filename):\r\n    hash_sha256 = hashlib.sha256()\r\n    blksize = 1024 * 1024\r\n\r\n    with open(filename, \"rb\") as f:\r\n        for chunk in iter(lambda: f.read(blksize), b\"\"):\r\n            hash_sha256.update(chunk)\r\n\r\n    return hash_sha256.hexdigest()\r\n\r\n\r\ndef sha256_from_cache(filename, title, use_addnet_hash=False):\r\n    hashes = cache(\"hashes-addnet\") if use_addnet_hash else cache(\"hashes\")\r\n    try:\r\n        ondisk_mtime = os.path.getmtime(filename)\r\n    except FileNotFoundError:\r\n        return None\r\n\r\n    if title not in hashes:\r\n        return None\r\n\r\n    cached_sha256 = hashes[title].get(\"sha256\", None)\r\n    cached_mtime = hashes[title].get(\"mtime\", 0)\r\n\r\n    if ondisk_mtime > cached_mtime or cached_sha256 is None:\r\n        return None\r\n\r\n    return cached_sha256\r\n\r\n\r\ndef sha256(filename, title, use_addnet_hash=False):\r\n    hashes = cache(\"hashes-addnet\") if use_addnet_hash else cache(\"hashes\")\r\n\r\n    sha256_value = sha256_from_cache(filename, title, use_addnet_hash)\r\n    if sha256_value is not None:\r\n        return sha256_value\r\n\r\n    if shared.cmd_opts.no_hashing:\r\n        return None\r\n\r\n    print(f\"Calculating sha256 for {filename}: \", end='')\r\n    if use_addnet_hash:\r\n        with open(filename, \"rb\") as file:\r\n            sha256_value = addnet_hash_safetensors(file)\r\n    else:\r\n        sha256_value = calculate_sha256(filename)\r\n    print(f\"{sha256_value}\")\r\n\r\n    hashes[title] = {\r\n        \"mtime\": os.path.getmtime(filename),\r\n        \"sha256\": sha256_value,\r\n    }\r\n\r\n    dump_cache()\r\n\r\n    return sha256_value\r\n\r\n\r\ndef addnet_hash_safetensors(b):\r\n    \"\"\"kohya-ss hash for safetensors from 
https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py\"\"\"\r\n    hash_sha256 = hashlib.sha256()\r\n    blksize = 1024 * 1024\r\n\r\n    b.seek(0)\r\n    header = b.read(8)\r\n    n = int.from_bytes(header, \"little\")\r\n\r\n    offset = n + 8\r\n    b.seek(offset)\r\n    for chunk in iter(lambda: b.read(blksize), b\"\"):\r\n        hash_sha256.update(chunk)\r\n\r\n    return hash_sha256.hexdigest()\r\n\r\n"
  },
  {
    "path": "modules/hat_model.py",
    "content": "import os\r\nimport sys\r\n\r\nfrom modules import modelloader, devices\r\nfrom modules.shared import opts\r\nfrom modules.upscaler import Upscaler, UpscalerData\r\nfrom modules.upscaler_utils import upscale_with_model\r\n\r\n\r\nclass UpscalerHAT(Upscaler):\r\n    def __init__(self, dirname):\r\n        self.name = \"HAT\"\r\n        self.scalers = []\r\n        self.user_path = dirname\r\n        super().__init__()\r\n        for file in self.find_models(ext_filter=[\".pt\", \".pth\"]):\r\n            name = modelloader.friendly_name(file)\r\n            scale = 4  # TODO: scale might not be 4, but we can't know without loading the model\r\n            scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)\r\n            self.scalers.append(scaler_data)\r\n\r\n    def do_upscale(self, img, selected_model):\r\n        try:\r\n            model = self.load_model(selected_model)\r\n        except Exception as e:\r\n            print(f\"Unable to load HAT model {selected_model}: {e}\", file=sys.stderr)\r\n            return img\r\n        model.to(devices.device_esrgan)  # TODO: should probably be device_hat\r\n        return upscale_with_model(\r\n            model,\r\n            img,\r\n            tile_size=opts.ESRGAN_tile,  # TODO: should probably be HAT_tile\r\n            tile_overlap=opts.ESRGAN_tile_overlap,  # TODO: should probably be HAT_tile_overlap\r\n        )\r\n\r\n    def load_model(self, path: str):\r\n        if not os.path.isfile(path):\r\n            raise FileNotFoundError(f\"Model file {path} not found\")\r\n        return modelloader.load_spandrel_model(\r\n            path,\r\n            device=devices.device_esrgan,  # TODO: should probably be device_hat\r\n            expected_architecture='HAT',\r\n        )\r\n"
  },
  {
    "path": "modules/hypernetworks/hypernetwork.py",
    "content": "import datetime\r\nimport glob\r\nimport html\r\nimport os\r\nimport inspect\r\nfrom contextlib import closing\r\n\r\nimport modules.textual_inversion.dataset\r\nimport torch\r\nimport tqdm\r\nfrom einops import rearrange, repeat\r\nfrom ldm.util import default\r\nfrom modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors\r\nfrom modules.textual_inversion import textual_inversion, saving_settings\r\nfrom modules.textual_inversion.learn_schedule import LearnRateScheduler\r\nfrom torch import einsum\r\nfrom torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_\r\n\r\nfrom collections import deque\r\nfrom statistics import stdev, mean\r\n\r\n\r\noptimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != \"Optimizer\"}\r\n\r\nclass HypernetworkModule(torch.nn.Module):\r\n    activation_dict = {\r\n        \"linear\": torch.nn.Identity,\r\n        \"relu\": torch.nn.ReLU,\r\n        \"leakyrelu\": torch.nn.LeakyReLU,\r\n        \"elu\": torch.nn.ELU,\r\n        \"swish\": torch.nn.Hardswish,\r\n        \"tanh\": torch.nn.Tanh,\r\n        \"sigmoid\": torch.nn.Sigmoid,\r\n    }\r\n    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})\r\n\r\n    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',\r\n                 add_layer_norm=False, activate_output=False, dropout_structure=None):\r\n        super().__init__()\r\n\r\n        self.multiplier = 1.0\r\n\r\n        assert layer_structure is not None, \"layer_structure must not be None\"\r\n        assert layer_structure[0] == 1, \"Multiplier Sequence should start with size 1!\"\r\n        assert layer_structure[-1] == 1, \"Multiplier 
Sequence should end with size 1!\"\r\n\r\n        linears = []\r\n        for i in range(len(layer_structure) - 1):\r\n\r\n            # Add a fully-connected layer\r\n            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))\r\n\r\n            # Add an activation func except last layer\r\n            if activation_func == \"linear\" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):\r\n                pass\r\n            elif activation_func in self.activation_dict:\r\n                linears.append(self.activation_dict[activation_func]())\r\n            else:\r\n                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')\r\n\r\n            # Add layer normalization\r\n            if add_layer_norm:\r\n                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))\r\n\r\n            # Everything should be now parsed into dropout structure, and applied here.\r\n            # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0.\r\n            if dropout_structure is not None and dropout_structure[i+1] > 0:\r\n                assert 0 < dropout_structure[i+1] < 1, \"Dropout probability should be 0 or float between 0 and 1!\"\r\n                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))\r\n            # Code explanation : [1, 2, 1] -> dropout is missing when last_layer_dropout is false. 
[1, 2, 2, 1] -> [0, 0.3, 0, 0], when its True, [0, 0.3, 0.3, 0].\r\n\r\n        self.linear = torch.nn.Sequential(*linears)\r\n\r\n        if state_dict is not None:\r\n            self.fix_old_state_dict(state_dict)\r\n            self.load_state_dict(state_dict)\r\n        else:\r\n            for layer in self.linear:\r\n                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:\r\n                    w, b = layer.weight.data, layer.bias.data\r\n                    if weight_init == \"Normal\" or type(layer) == torch.nn.LayerNorm:\r\n                        normal_(w, mean=0.0, std=0.01)\r\n                        normal_(b, mean=0.0, std=0)\r\n                    elif weight_init == 'XavierUniform':\r\n                        xavier_uniform_(w)\r\n                        zeros_(b)\r\n                    elif weight_init == 'XavierNormal':\r\n                        xavier_normal_(w)\r\n                        zeros_(b)\r\n                    elif weight_init == 'KaimingUniform':\r\n                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')\r\n                        zeros_(b)\r\n                    elif weight_init == 'KaimingNormal':\r\n                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')\r\n                        zeros_(b)\r\n                    else:\r\n                        raise KeyError(f\"Key {weight_init} is not defined as initialization!\")\r\n        devices.torch_npu_set_device()\r\n        self.to(devices.device)\r\n\r\n    def fix_old_state_dict(self, state_dict):\r\n        changes = {\r\n            'linear1.bias': 'linear.0.bias',\r\n            'linear1.weight': 'linear.0.weight',\r\n            'linear2.bias': 'linear.1.bias',\r\n            'linear2.weight': 'linear.1.weight',\r\n        }\r\n\r\n        for fr, to in changes.items():\r\n            x = state_dict.get(fr, None)\r\n            
if x is None:\r\n                continue\r\n\r\n            del state_dict[fr]\r\n            state_dict[to] = x\r\n\r\n    def forward(self, x):\r\n        return x + self.linear(x) * (self.multiplier if not self.training else 1)\r\n\r\n    def trainables(self):\r\n        layer_structure = []\r\n        for layer in self.linear:\r\n            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:\r\n                layer_structure += [layer.weight, layer.bias]\r\n        return layer_structure\r\n\r\n\r\n#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.\r\ndef parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):\r\n    if layer_structure is None:\r\n        layer_structure = [1, 2, 1]\r\n    if not use_dropout:\r\n        return [0] * len(layer_structure)\r\n    dropout_values = [0]\r\n    dropout_values.extend([0.3] * (len(layer_structure) - 3))\r\n    if last_layer_dropout:\r\n        dropout_values.append(0.3)\r\n    else:\r\n        dropout_values.append(0)\r\n    dropout_values.append(0)\r\n    return dropout_values\r\n\r\n\r\nclass Hypernetwork:\r\n    filename = None\r\n    name = None\r\n\r\n    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):\r\n        self.filename = None\r\n        self.name = name\r\n        self.layers = {}\r\n        self.step = 0\r\n        self.sd_checkpoint = None\r\n        self.sd_checkpoint_name = None\r\n        self.layer_structure = layer_structure\r\n        self.activation_func = activation_func\r\n        self.weight_init = weight_init\r\n        self.add_layer_norm = add_layer_norm\r\n        self.use_dropout = use_dropout\r\n        self.activate_output = activate_output\r\n        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)\r\n        
self.dropout_structure = kwargs.get('dropout_structure', None)\r\n        if self.dropout_structure is None:\r\n            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)\r\n        self.optimizer_name = None\r\n        self.optimizer_state_dict = None\r\n        self.optional_info = None\r\n\r\n        for size in enable_sizes or []:\r\n            self.layers[size] = (\r\n                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,\r\n                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),\r\n                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,\r\n                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),\r\n            )\r\n        self.eval()\r\n\r\n    def weights(self):\r\n        res = []\r\n        for layers in self.layers.values():\r\n            for layer in layers:\r\n                res += layer.parameters()\r\n        return res\r\n\r\n    def train(self, mode=True):\r\n        for layers in self.layers.values():\r\n            for layer in layers:\r\n                layer.train(mode=mode)\r\n                for param in layer.parameters():\r\n                    param.requires_grad = mode\r\n\r\n    def to(self, device):\r\n        for layers in self.layers.values():\r\n            for layer in layers:\r\n                layer.to(device)\r\n\r\n        return self\r\n\r\n    def set_multiplier(self, multiplier):\r\n        for layers in self.layers.values():\r\n            for layer in layers:\r\n                layer.multiplier = multiplier\r\n\r\n        return self\r\n\r\n    def eval(self):\r\n        for layers in self.layers.values():\r\n            for layer in layers:\r\n                layer.eval()\r\n                for param in 
layer.parameters():\r\n                    param.requires_grad = False\r\n\r\n    def save(self, filename):\r\n        state_dict = {}\r\n        optimizer_saved_dict = {}\r\n\r\n        for k, v in self.layers.items():\r\n            state_dict[k] = (v[0].state_dict(), v[1].state_dict())\r\n\r\n        state_dict['step'] = self.step\r\n        state_dict['name'] = self.name\r\n        state_dict['layer_structure'] = self.layer_structure\r\n        state_dict['activation_func'] = self.activation_func\r\n        state_dict['is_layer_norm'] = self.add_layer_norm\r\n        state_dict['weight_initialization'] = self.weight_init\r\n        state_dict['sd_checkpoint'] = self.sd_checkpoint\r\n        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name\r\n        state_dict['activate_output'] = self.activate_output\r\n        state_dict['use_dropout'] = self.use_dropout\r\n        state_dict['dropout_structure'] = self.dropout_structure\r\n        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout\r\n        state_dict['optional_info'] = self.optional_info if self.optional_info else None\r\n\r\n        if self.optimizer_name is not None:\r\n            optimizer_saved_dict['optimizer_name'] = self.optimizer_name\r\n\r\n        torch.save(state_dict, filename)\r\n        if shared.opts.save_optimizer_state and self.optimizer_state_dict:\r\n            optimizer_saved_dict['hash'] = self.shorthash()\r\n            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict\r\n            torch.save(optimizer_saved_dict, filename + '.optim')\r\n\r\n    def load(self, filename):\r\n        self.filename = filename\r\n        if self.name is None:\r\n            self.name = os.path.splitext(os.path.basename(filename))[0]\r\n\r\n        state_dict = torch.load(filename, map_location='cpu')\r\n\r\n        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])\r\n 
       self.optional_info = state_dict.get('optional_info', None)\r\n        self.activation_func = state_dict.get('activation_func', None)\r\n        self.weight_init = state_dict.get('weight_initialization', 'Normal')\r\n        self.add_layer_norm = state_dict.get('is_layer_norm', False)\r\n        self.dropout_structure = state_dict.get('dropout_structure', None)\r\n        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)\r\n        self.activate_output = state_dict.get('activate_output', True)\r\n        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)\r\n        # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0.\r\n        if self.dropout_structure is None:\r\n            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)\r\n\r\n        if shared.opts.print_hypernet_extra:\r\n            if self.optional_info is not None:\r\n                print(f\"  INFO:\\n {self.optional_info}\\n\")\r\n\r\n            print(f\"  Layer structure: {self.layer_structure}\")\r\n            print(f\"  Activation function: {self.activation_func}\")\r\n            print(f\"  Weight initialization: {self.weight_init}\")\r\n            print(f\"  Layer norm: {self.add_layer_norm}\")\r\n            print(f\"  Dropout usage: {self.use_dropout}\" )\r\n            print(f\"  Activate last layer: {self.activate_output}\")\r\n            print(f\"  Dropout structure: {self.dropout_structure}\")\r\n\r\n        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}\r\n\r\n        if self.shorthash() == optimizer_saved_dict.get('hash', None):\r\n            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)\r\n        else:\r\n            
self.optimizer_state_dict = None\r\n        if self.optimizer_state_dict:\r\n            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')\r\n            if shared.opts.print_hypernet_extra:\r\n                print(\"Loaded existing optimizer from checkpoint\")\r\n                print(f\"Optimizer name is {self.optimizer_name}\")\r\n        else:\r\n            self.optimizer_name = \"AdamW\"\r\n            if shared.opts.print_hypernet_extra:\r\n                print(\"No saved optimizer exists in checkpoint\")\r\n\r\n        for size, sd in state_dict.items():\r\n            if type(size) == int:\r\n                self.layers[size] = (\r\n                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,\r\n                                       self.add_layer_norm, self.activate_output, self.dropout_structure),\r\n                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,\r\n                                       self.add_layer_norm, self.activate_output, self.dropout_structure),\r\n                )\r\n\r\n        self.name = state_dict.get('name', self.name)\r\n        self.step = state_dict.get('step', 0)\r\n        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)\r\n        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)\r\n        self.eval()\r\n\r\n    def shorthash(self):\r\n        sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')\r\n\r\n        return sha256[0:10] if sha256 else None\r\n\r\n\r\ndef list_hypernetworks(path):\r\n    res = {}\r\n    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True), key=str.lower):\r\n        name = os.path.splitext(os.path.basename(filename))[0]\r\n        # Prevent a hypothetical \"None.pt\" from being listed.\r\n        if name != \"None\":\r\n            res[name] = filename\r\n    return res\r\n\r\n\r\ndef 
load_hypernetwork(name):\r\n    path = shared.hypernetworks.get(name, None)\r\n\r\n    if path is None:\r\n        return None\r\n\r\n    try:\r\n        hypernetwork = Hypernetwork()\r\n        hypernetwork.load(path)\r\n        return hypernetwork\r\n    except Exception:\r\n        errors.report(f\"Error loading hypernetwork {path}\", exc_info=True)\r\n        return None\r\n\r\n\r\ndef load_hypernetworks(names, multipliers=None):\r\n    already_loaded = {}\r\n\r\n    for hypernetwork in shared.loaded_hypernetworks:\r\n        if hypernetwork.name in names:\r\n            already_loaded[hypernetwork.name] = hypernetwork\r\n\r\n    shared.loaded_hypernetworks.clear()\r\n\r\n    for i, name in enumerate(names):\r\n        hypernetwork = already_loaded.get(name, None)\r\n        if hypernetwork is None:\r\n            hypernetwork = load_hypernetwork(name)\r\n\r\n        if hypernetwork is None:\r\n            continue\r\n\r\n        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)\r\n        shared.loaded_hypernetworks.append(hypernetwork)\r\n\r\n\r\ndef apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):\r\n    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)\r\n\r\n    if hypernetwork_layers is None:\r\n        return context_k, context_v\r\n\r\n    if layer is not None:\r\n        layer.hyper_k = hypernetwork_layers[0]\r\n        layer.hyper_v = hypernetwork_layers[1]\r\n\r\n    context_k = devices.cond_cast_unet(hypernetwork_layers[0](devices.cond_cast_float(context_k)))\r\n    context_v = devices.cond_cast_unet(hypernetwork_layers[1](devices.cond_cast_float(context_v)))\r\n    return context_k, context_v\r\n\r\n\r\ndef apply_hypernetworks(hypernetworks, context, layer=None):\r\n    context_k = context\r\n    context_v = context\r\n    for hypernetwork in hypernetworks:\r\n        context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, 
context_v, layer)\r\n\r\n    return context_k, context_v\r\n\r\n\r\ndef attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):\r\n    h = self.heads\r\n\r\n    q = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)\r\n    k = self.to_k(context_k)\r\n    v = self.to_v(context_v)\r\n\r\n    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))\r\n\r\n    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale\r\n\r\n    if mask is not None:\r\n        mask = rearrange(mask, 'b ... -> b (...)')\r\n        max_neg_value = -torch.finfo(sim.dtype).max\r\n        mask = repeat(mask, 'b j -> (b h) () j', h=h)\r\n        sim.masked_fill_(~mask, max_neg_value)\r\n\r\n    # attention, what we cannot get enough of\r\n    attn = sim.softmax(dim=-1)\r\n\r\n    out = einsum('b i j, b j d -> b i d', attn, v)\r\n    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)\r\n    return self.to_out(out)\r\n\r\n\r\ndef stack_conds(conds):\r\n    if len(conds) == 1:\r\n        return torch.stack(conds)\r\n\r\n    # same as in reconstruct_multicond_batch\r\n    token_count = max([x.shape[0] for x in conds])\r\n    for i in range(len(conds)):\r\n        if conds[i].shape[0] != token_count:\r\n            last_vector = conds[i][-1:]\r\n            last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])\r\n            conds[i] = torch.vstack([conds[i], last_vector_repeated])\r\n\r\n    return torch.stack(conds)\r\n\r\n\r\ndef statistics(data):\r\n    if len(data) < 2:\r\n        std = 0\r\n    else:\r\n        std = stdev(data)\r\n    total_information = f\"loss:{mean(data):.3f}\" + u\"\\u00B1\" + f\"({std/ (len(data) ** 0.5):.3f})\"\r\n    recent_data = data[-32:]\r\n    if len(recent_data) < 2:\r\n        std = 0\r\n    else:\r\n        std = stdev(recent_data)\r\n    recent_information = f\"recent 32 
loss:{mean(recent_data):.3f}\" + u\"\\u00B1\" + f\"({std / (len(recent_data) ** 0.5):.3f})\"\r\n    return total_information, recent_information\r\n\r\n\r\ndef create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):\r\n    # Remove illegal characters from name.\r\n    name = \"\".join( x for x in name if (x.isalnum() or x in \"._- \"))\r\n    assert name, \"Name cannot be empty!\"\r\n\r\n    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f\"{name}.pt\")\r\n    if not overwrite_old:\r\n        assert not os.path.exists(fn), f\"file {fn} already exists\"\r\n\r\n    if type(layer_structure) == str:\r\n        layer_structure = [float(x.strip()) for x in layer_structure.split(\",\")]\r\n\r\n    if use_dropout and dropout_structure and type(dropout_structure) == str:\r\n        dropout_structure = [float(x.strip()) for x in dropout_structure.split(\",\")]\r\n    else:\r\n        dropout_structure = [0] * len(layer_structure)\r\n\r\n    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(\r\n        name=name,\r\n        enable_sizes=[int(x) for x in enable_sizes],\r\n        layer_structure=layer_structure,\r\n        activation_func=activation_func,\r\n        weight_init=weight_init,\r\n        add_layer_norm=add_layer_norm,\r\n        use_dropout=use_dropout,\r\n        dropout_structure=dropout_structure\r\n    )\r\n    hypernet.save(fn)\r\n\r\n    shared.reload_hypernetworks()\r\n\r\n\r\ndef train_hypernetwork(id_task, hypernetwork_name: str, learn_rate: float, batch_size: int, gradient_step: int, data_root: str, log_directory: str, training_width: int, training_height: int, varsize: bool, steps: int, clip_grad_mode: str, clip_grad_value: float, shuffle_tags: bool, tag_drop_out: bool, latent_sampling_method: str, use_weight: bool, create_image_every: int, save_hypernetwork_every: int, template_filename: str, 
preview_from_txt2img: bool, preview_prompt: str, preview_negative_prompt: str, preview_steps: int, preview_sampler_name: str, preview_cfg_scale: float, preview_seed: int, preview_width: int, preview_height: int):\r\n    from modules import images, processing\r\n\r\n    save_hypernetwork_every = save_hypernetwork_every or 0\r\n    create_image_every = create_image_every or 0\r\n    template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)\r\n    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name=\"hypernetwork\")\r\n    template_file = template_file.path\r\n\r\n    path = shared.hypernetworks.get(hypernetwork_name, None)\r\n    hypernetwork = Hypernetwork()\r\n    hypernetwork.load(path)\r\n    shared.loaded_hypernetworks = [hypernetwork]\r\n\r\n    shared.state.job = \"train-hypernetwork\"\r\n    shared.state.textinfo = \"Initializing hypernetwork training...\"\r\n    shared.state.job_count = steps\r\n\r\n    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]\r\n    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')\r\n\r\n    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime(\"%Y-%m-%d\"), hypernetwork_name)\r\n    unload = shared.opts.unload_models_when_training\r\n\r\n    if save_hypernetwork_every > 0:\r\n        hypernetwork_dir = os.path.join(log_directory, \"hypernetworks\")\r\n        os.makedirs(hypernetwork_dir, exist_ok=True)\r\n    else:\r\n        hypernetwork_dir = None\r\n\r\n    if create_image_every > 0:\r\n        images_dir = os.path.join(log_directory, \"images\")\r\n        os.makedirs(images_dir, exist_ok=True)\r\n    else:\r\n        images_dir = None\r\n\r\n    checkpoint = sd_models.select_checkpoint()\r\n\r\n    initial_step = hypernetwork.step or 0\r\n    if initial_step >= 
steps:\r\n        shared.state.textinfo = \"Model has already been trained beyond specified max steps\"\r\n        return hypernetwork, filename\r\n\r\n    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)\r\n\r\n    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == \"value\" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == \"norm\" else None\r\n    if clip_grad:\r\n        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)\r\n\r\n    if shared.opts.training_enable_tensorboard:\r\n        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)\r\n\r\n    # dataset loading may take a while, so input validations and early returns should be done before this\r\n    shared.state.textinfo = f\"Preparing dataset from {html.escape(data_root)}...\"\r\n\r\n    pin_memory = shared.opts.pin_memory\r\n\r\n    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)\r\n\r\n    if shared.opts.save_training_settings_to_txt:\r\n        saved_params = dict(\r\n            model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),\r\n            **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}\r\n        )\r\n        saving_settings.save_settings_to_file(log_directory, {**saved_params, **locals()})\r\n\r\n    latent_sampling_method = ds.latent_sampling_method\r\n\r\n    dl = 
modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)\r\n\r\n    old_parallel_processing_allowed = shared.parallel_processing_allowed\r\n\r\n    if unload:\r\n        shared.parallel_processing_allowed = False\r\n        shared.sd_model.cond_stage_model.to(devices.cpu)\r\n        shared.sd_model.first_stage_model.to(devices.cpu)\r\n\r\n    weights = hypernetwork.weights()\r\n    hypernetwork.train()\r\n\r\n    # Here we use optimizer from saved HN, or we can specify as UI option.\r\n    if hypernetwork.optimizer_name in optimizer_dict:\r\n        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)\r\n        optimizer_name = hypernetwork.optimizer_name\r\n    else:\r\n        print(f\"Optimizer type {hypernetwork.optimizer_name} is not defined!\")\r\n        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)\r\n        optimizer_name = 'AdamW'\r\n\r\n    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.\r\n        try:\r\n            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)\r\n        except RuntimeError as e:\r\n            print(\"Cannot resume from saved optimizer!\")\r\n            print(e)\r\n\r\n    scaler = torch.cuda.amp.GradScaler()\r\n\r\n    batch_size = ds.batch_size\r\n    gradient_step = ds.gradient_step\r\n    # n steps = batch_size * gradient_step * n image processed\r\n    steps_per_epoch = len(ds) // batch_size // gradient_step\r\n    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step\r\n    loss_step = 0\r\n    _loss_step = 0 #internal\r\n    # size = len(ds.indexes)\r\n    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))\r\n    loss_logging = deque(maxlen=len(ds) * 3)  # this should be configurable parameter, this is 3 * epoch(dataset size)\r\n   
 # losses = torch.zeros((size,))\r\n    # previous_mean_losses = [0]\r\n    # previous_mean_loss = 0\r\n    # print(\"Mean loss of {} elements\".format(size))\r\n\r\n    steps_without_grad = 0\r\n\r\n    last_saved_file = \"<none>\"\r\n    last_saved_image = \"<none>\"\r\n    forced_filename = \"<none>\"\r\n\r\n    pbar = tqdm.tqdm(total=steps - initial_step)\r\n    try:\r\n        sd_hijack_checkpoint.add()\r\n\r\n        for _ in range((steps-initial_step) * gradient_step):\r\n            if scheduler.finished:\r\n                break\r\n            if shared.state.interrupted:\r\n                break\r\n            for j, batch in enumerate(dl):\r\n                # works as a drop_last=True for gradient accumulation\r\n                if j == max_steps_per_epoch:\r\n                    break\r\n                scheduler.apply(optimizer, hypernetwork.step)\r\n                if scheduler.finished:\r\n                    break\r\n                if shared.state.interrupted:\r\n                    break\r\n\r\n                if clip_grad:\r\n                    clip_grad_sched.step(hypernetwork.step)\r\n\r\n                with devices.autocast():\r\n                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)\r\n                    if use_weight:\r\n                        w = batch.weight.to(devices.device, non_blocking=pin_memory)\r\n                    if tag_drop_out != 0 or shuffle_tags:\r\n                        shared.sd_model.cond_stage_model.to(devices.device)\r\n                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)\r\n                        shared.sd_model.cond_stage_model.to(devices.cpu)\r\n                    else:\r\n                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)\r\n                    if use_weight:\r\n                        loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step\r\n        
                del w\r\n                    else:\r\n                        loss = shared.sd_model.forward(x, c)[0] / gradient_step\r\n                    del x\r\n                    del c\r\n\r\n                    _loss_step += loss.item()\r\n                scaler.scale(loss).backward()\r\n\r\n                # go back until we reach gradient accumulation steps\r\n                if (j + 1) % gradient_step != 0:\r\n                    continue\r\n                loss_logging.append(_loss_step)\r\n                if clip_grad:\r\n                    clip_grad(weights, clip_grad_sched.learn_rate)\r\n\r\n                scaler.step(optimizer)\r\n                scaler.update()\r\n                hypernetwork.step += 1\r\n                pbar.update()\r\n                optimizer.zero_grad(set_to_none=True)\r\n                loss_step = _loss_step\r\n                _loss_step = 0\r\n\r\n                steps_done = hypernetwork.step + 1\r\n\r\n                epoch_num = hypernetwork.step // steps_per_epoch\r\n                epoch_step = hypernetwork.step % steps_per_epoch\r\n\r\n                description = f\"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}\"\r\n                pbar.set_description(description)\r\n                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:\r\n                    # Before saving, change name to match current checkpoint.\r\n                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'\r\n                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')\r\n                    hypernetwork.optimizer_name = optimizer_name\r\n                    if shared.opts.save_optimizer_state:\r\n                        hypernetwork.optimizer_state_dict = optimizer.state_dict()\r\n                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)\r\n                    
hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.\r\n\r\n\r\n\r\n                if shared.opts.training_enable_tensorboard:\r\n                    epoch_num = hypernetwork.step // len(ds)\r\n                    epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1\r\n                    mean_loss = sum(loss_logging) / len(loss_logging)\r\n                    textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)\r\n\r\n                textual_inversion.write_loss(log_directory, \"hypernetwork_loss.csv\", hypernetwork.step, steps_per_epoch, {\r\n                    \"loss\": f\"{loss_step:.7f}\",\r\n                    \"learn_rate\": scheduler.learn_rate\r\n                })\r\n\r\n                if images_dir is not None and steps_done % create_image_every == 0:\r\n                    forced_filename = f'{hypernetwork_name}-{steps_done}'\r\n                    last_saved_image = os.path.join(images_dir, forced_filename)\r\n                    hypernetwork.eval()\r\n                    rng_state = torch.get_rng_state()\r\n                    cuda_rng_state = None\r\n                    if torch.cuda.is_available():\r\n                        cuda_rng_state = torch.cuda.get_rng_state_all()\r\n                    shared.sd_model.cond_stage_model.to(devices.device)\r\n                    shared.sd_model.first_stage_model.to(devices.device)\r\n\r\n                    p = processing.StableDiffusionProcessingTxt2Img(\r\n                        sd_model=shared.sd_model,\r\n                        do_not_save_grid=True,\r\n                        do_not_save_samples=True,\r\n                    )\r\n\r\n                    p.disable_extra_networks = True\r\n\r\n                    if preview_from_txt2img:\r\n                        p.prompt = preview_prompt\r\n                        p.negative_prompt = 
preview_negative_prompt\r\n                        p.steps = preview_steps\r\n                        p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]\r\n                        p.cfg_scale = preview_cfg_scale\r\n                        p.seed = preview_seed\r\n                        p.width = preview_width\r\n                        p.height = preview_height\r\n                    else:\r\n                        p.prompt = batch.cond_text[0]\r\n                        p.steps = 20\r\n                        p.width = training_width\r\n                        p.height = training_height\r\n\r\n                    preview_text = p.prompt\r\n\r\n                    with closing(p):\r\n                        processed = processing.process_images(p)\r\n                        image = processed.images[0] if len(processed.images) > 0 else None\r\n\r\n                    if unload:\r\n                        shared.sd_model.cond_stage_model.to(devices.cpu)\r\n                        shared.sd_model.first_stage_model.to(devices.cpu)\r\n                    torch.set_rng_state(rng_state)\r\n                    if torch.cuda.is_available():\r\n                        torch.cuda.set_rng_state_all(cuda_rng_state)\r\n                    hypernetwork.train()\r\n                    if image is not None:\r\n                        shared.state.assign_current_image(image)\r\n                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:\r\n                            textual_inversion.tensorboard_add_image(tensorboard_writer,\r\n                                                                    f\"Validation at epoch {epoch_num}\", image,\r\n                                                                    hypernetwork.step)\r\n                        last_saved_image, last_text_info = images.save_image(image, images_dir, \"\", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, 
forced_filename=forced_filename, save_to_dirs=False)\r\n                        last_saved_image += f\", prompt: {preview_text}\"\r\n\r\n                shared.state.job_no = hypernetwork.step\r\n\r\n                shared.state.textinfo = f\"\"\"\r\n<p>\r\nLoss: {loss_step:.7f}<br/>\r\nStep: {steps_done}<br/>\r\nLast prompt: {html.escape(batch.cond_text[0])}<br/>\r\nLast saved hypernetwork: {html.escape(last_saved_file)}<br/>\r\nLast saved image: {html.escape(last_saved_image)}<br/>\r\n</p>\r\n\"\"\"\r\n    except Exception:\r\n        errors.report(\"Exception in training hypernetwork\", exc_info=True)\r\n    finally:\r\n        pbar.leave = False\r\n        pbar.close()\r\n        hypernetwork.eval()\r\n        sd_hijack_checkpoint.remove()\r\n\r\n\r\n\r\n    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')\r\n    hypernetwork.optimizer_name = optimizer_name\r\n    if shared.opts.save_optimizer_state:\r\n        hypernetwork.optimizer_state_dict = optimizer.state_dict()\r\n    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)\r\n\r\n    del optimizer\r\n    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.\r\n    shared.sd_model.cond_stage_model.to(devices.device)\r\n    shared.sd_model.first_stage_model.to(devices.device)\r\n    shared.parallel_processing_allowed = old_parallel_processing_allowed\r\n\r\n    return hypernetwork, filename\r\n\r\ndef save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):\r\n    old_hypernetwork_name = hypernetwork.name\r\n    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, \"sd_checkpoint\") else None\r\n    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, \"sd_checkpoint_name\") else None\r\n    try:\r\n        hypernetwork.sd_checkpoint = checkpoint.shorthash\r\n        hypernetwork.sd_checkpoint_name = checkpoint.model_name\r\n        hypernetwork.name = 
hypernetwork_name\r\n        hypernetwork.save(filename)\r\n    except:\r\n        hypernetwork.sd_checkpoint = old_sd_checkpoint\r\n        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name\r\n        hypernetwork.name = old_hypernetwork_name\r\n        raise\r\n"
  },
  {
    "path": "modules/hypernetworks/ui.py",
    "content": "import html\r\n\r\nimport gradio as gr\r\nimport modules.hypernetworks.hypernetwork\r\nfrom modules import devices, sd_hijack, shared\r\n\r\nnot_available = [\"hardswish\", \"multiheadattention\"]\r\nkeys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]\r\n\r\n\r\ndef create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):\r\n    filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)\r\n\r\n    return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f\"Created: {filename}\", \"\"\r\n\r\n\r\ndef train_hypernetwork(*args):\r\n    shared.loaded_hypernetworks = []\r\n\r\n    assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible'\r\n\r\n    try:\r\n        sd_hijack.undo_optimizations()\r\n\r\n        hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args)\r\n\r\n        res = f\"\"\"\r\nTraining {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps.\r\nHypernetwork saved to {html.escape(filename)}\r\n\"\"\"\r\n        return res, \"\"\r\n    except Exception:\r\n        raise\r\n    finally:\r\n        shared.sd_model.cond_stage_model.to(devices.device)\r\n        shared.sd_model.first_stage_model.to(devices.device)\r\n        sd_hijack.apply_optimizations()\r\n\r\n"
  },
  {
    "path": "modules/images.py",
    "content": "from __future__ import annotations\r\n\r\nimport datetime\r\nimport functools\r\nimport pytz\r\nimport io\r\nimport math\r\nimport os\r\nfrom collections import namedtuple\r\nimport re\r\n\r\nimport numpy as np\r\nimport piexif\r\nimport piexif.helper\r\nfrom PIL import Image, ImageFont, ImageDraw, ImageColor, PngImagePlugin, ImageOps\r\n# pillow_avif needs to be imported somewhere in code for it to work\r\nimport pillow_avif # noqa: F401\r\nimport string\r\nimport json\r\nimport hashlib\r\n\r\nfrom modules import sd_samplers, shared, script_callbacks, errors\r\nfrom modules.paths_internal import roboto_ttf_file\r\nfrom modules.shared import opts\r\n\r\nLANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)\r\n\r\n\r\ndef get_font(fontsize: int):\r\n    try:\r\n        return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)\r\n    except Exception:\r\n        return ImageFont.truetype(roboto_ttf_file, fontsize)\r\n\r\n\r\ndef image_grid(imgs, batch_size=1, rows=None):\r\n    if rows is None:\r\n        if opts.n_rows > 0:\r\n            rows = opts.n_rows\r\n        elif opts.n_rows == 0:\r\n            rows = batch_size\r\n        elif opts.grid_prevent_empty_spots:\r\n            rows = math.floor(math.sqrt(len(imgs)))\r\n            while len(imgs) % rows != 0:\r\n                rows -= 1\r\n        else:\r\n            rows = math.sqrt(len(imgs))\r\n            rows = round(rows)\r\n    if rows > len(imgs):\r\n        rows = len(imgs)\r\n\r\n    cols = math.ceil(len(imgs) / rows)\r\n\r\n    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)\r\n    script_callbacks.image_grid_callback(params)\r\n\r\n    w, h = map(max, zip(*(img.size for img in imgs)))\r\n    grid_background_color = ImageColor.getcolor(opts.grid_background_color, 'RGB')\r\n    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=grid_background_color)\r\n\r\n    for i, img in enumerate(params.imgs):\r\n    
    img_w, img_h = img.size\r\n        w_offset, h_offset = 0 if img_w == w else (w - img_w) // 2, 0 if img_h == h else (h - img_h) // 2\r\n        grid.paste(img, box=(i % params.cols * w + w_offset, i // params.cols * h + h_offset))\r\n\r\n    return grid\r\n\r\n\r\nclass Grid(namedtuple(\"_Grid\", [\"tiles\", \"tile_w\", \"tile_h\", \"image_w\", \"image_h\", \"overlap\"])):\r\n    @property\r\n    def tile_count(self) -> int:\r\n        \"\"\"\r\n        The total number of tiles in the grid.\r\n        \"\"\"\r\n        return sum(len(row[2]) for row in self.tiles)\r\n\r\n\r\ndef split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:\r\n    w, h = image.size\r\n\r\n    non_overlap_width = tile_w - overlap\r\n    non_overlap_height = tile_h - overlap\r\n\r\n    cols = math.ceil((w - overlap) / non_overlap_width)\r\n    rows = math.ceil((h - overlap) / non_overlap_height)\r\n\r\n    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0\r\n    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0\r\n\r\n    grid = Grid([], tile_w, tile_h, w, h, overlap)\r\n    for row in range(rows):\r\n        row_images = []\r\n\r\n        y = int(row * dy)\r\n\r\n        if y + tile_h >= h:\r\n            y = h - tile_h\r\n\r\n        for col in range(cols):\r\n            x = int(col * dx)\r\n\r\n            if x + tile_w >= w:\r\n                x = w - tile_w\r\n\r\n            tile = image.crop((x, y, x + tile_w, y + tile_h))\r\n\r\n            row_images.append([x, tile_w, tile])\r\n\r\n        grid.tiles.append([y, tile_h, row_images])\r\n\r\n    return grid\r\n\r\n\r\ndef combine_grid(grid):\r\n    def make_mask_image(r):\r\n        r = r * 255 / grid.overlap\r\n        r = r.astype(np.uint8)\r\n        return Image.fromarray(r, 'L')\r\n\r\n    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))\r\n    mask_h = make_mask_image(np.arange(grid.overlap, 
dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))\r\n\r\n    combined_image = Image.new(\"RGB\", (grid.image_w, grid.image_h))\r\n    for y, h, row in grid.tiles:\r\n        combined_row = Image.new(\"RGB\", (grid.image_w, h))\r\n        for x, w, tile in row:\r\n            if x == 0:\r\n                combined_row.paste(tile, (0, 0))\r\n                continue\r\n\r\n            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)\r\n            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))\r\n\r\n        if y == 0:\r\n            combined_image.paste(combined_row, (0, 0))\r\n            continue\r\n\r\n        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)\r\n        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))\r\n\r\n    return combined_image\r\n\r\n\r\nclass GridAnnotation:\r\n    def __init__(self, text='', is_active=True):\r\n        self.text = text\r\n        self.is_active = is_active\r\n        self.size = None\r\n\r\n\r\ndef draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):\r\n\r\n    color_active = ImageColor.getcolor(opts.grid_text_active_color, 'RGB')\r\n    color_inactive = ImageColor.getcolor(opts.grid_text_inactive_color, 'RGB')\r\n    color_background = ImageColor.getcolor(opts.grid_background_color, 'RGB')\r\n\r\n    def wrap(drawing, text, font, line_length):\r\n        lines = ['']\r\n        for word in text.split():\r\n            line = f'{lines[-1]} {word}'.strip()\r\n            if drawing.textlength(line, font=font) <= line_length:\r\n                lines[-1] = line\r\n            else:\r\n                lines.append(word)\r\n        return lines\r\n\r\n    def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):\r\n        for line in lines:\r\n            fnt = initial_fnt\r\n            
fontsize = initial_fontsize\r\n            while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:\r\n                fontsize -= 1\r\n                fnt = get_font(fontsize)\r\n            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor=\"mm\", align=\"center\")\r\n\r\n            if not line.is_active:\r\n                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)\r\n\r\n            draw_y += line.size[1] + line_spacing\r\n\r\n    fontsize = (width + height) // 25\r\n    line_spacing = fontsize // 2\r\n\r\n    fnt = get_font(fontsize)\r\n\r\n    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4\r\n\r\n    cols = im.width // width\r\n    rows = im.height // height\r\n\r\n    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'\r\n    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'\r\n\r\n    calc_img = Image.new(\"RGB\", (1, 1), color_background)\r\n    calc_d = ImageDraw.Draw(calc_img)\r\n\r\n    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):\r\n        items = [] + texts\r\n        texts.clear()\r\n\r\n        for line in items:\r\n            wrapped = wrap(calc_d, line.text, fnt, allowed_width)\r\n            texts += [GridAnnotation(x, line.is_active) for x in wrapped]\r\n\r\n        for line in texts:\r\n            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)\r\n            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])\r\n            line.allowed_width = allowed_width\r\n\r\n    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for 
lines in hor_texts]\r\n    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]\r\n\r\n    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2\r\n\r\n    result = Image.new(\"RGB\", (im.width + pad_left + margin * (cols-1), im.height + pad_top + margin * (rows-1)), color_background)\r\n\r\n    for row in range(rows):\r\n        for col in range(cols):\r\n            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))\r\n            result.paste(cell, (pad_left + (width + margin) * col, pad_top + (height + margin) * row))\r\n\r\n    d = ImageDraw.Draw(result)\r\n\r\n    for col in range(cols):\r\n        x = pad_left + (width + margin) * col + width / 2\r\n        y = pad_top / 2 - hor_text_heights[col] / 2\r\n\r\n        draw_texts(d, x, y, hor_texts[col], fnt, fontsize)\r\n\r\n    for row in range(rows):\r\n        x = pad_left / 2\r\n        y = pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2\r\n\r\n        draw_texts(d, x, y, ver_texts[row], fnt, fontsize)\r\n\r\n    return result\r\n\r\n\r\ndef draw_prompt_matrix(im, width, height, all_prompts, margin=0):\r\n    prompts = all_prompts[1:]\r\n    boundary = math.ceil(len(prompts) / 2)\r\n\r\n    prompts_horiz = prompts[:boundary]\r\n    prompts_vert = prompts[boundary:]\r\n\r\n    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]\r\n    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]\r\n\r\n    return draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin)\r\n\r\n\r\ndef resize_image(resize_mode, im, width, height, upscaler_name=None):\r\n    \"\"\"\r\n    Resizes an image with the specified resize_mode, width, and height.\r\n\r\n    Args:\r\n        
resize_mode: The mode to use when resizing the image.\r\n            0: Resize the image to the specified width and height.\r\n            1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.\r\n            2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.\r\n        im: The image to resize.\r\n        width: The width to resize the image to.\r\n        height: The height to resize the image to.\r\n        upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img.\r\n    \"\"\"\r\n\r\n    upscaler_name = upscaler_name or opts.upscaler_for_img2img\r\n\r\n    def resize(im, w, h):\r\n        if upscaler_name is None or upscaler_name == \"None\" or im.mode == 'L':\r\n            return im.resize((w, h), resample=LANCZOS)\r\n\r\n        scale = max(w / im.width, h / im.height)\r\n\r\n        if scale > 1.0:\r\n            upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name]\r\n            if len(upscalers) == 0:\r\n                upscaler = shared.sd_upscalers[0]\r\n                print(f\"could not find upscaler named {upscaler_name or '<empty string>'}, using {upscaler.name} as a fallback\")\r\n            else:\r\n                upscaler = upscalers[0]\r\n\r\n            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)\r\n\r\n        if im.width != w or im.height != h:\r\n            im = im.resize((w, h), resample=LANCZOS)\r\n\r\n        return im\r\n\r\n    if resize_mode == 0:\r\n        res = resize(im, width, height)\r\n\r\n    elif resize_mode == 1:\r\n        ratio = width / height\r\n        src_ratio = im.width / im.height\r\n\r\n        src_w = width if ratio > src_ratio else im.width * height // im.height\r\n        src_h = height if ratio <= 
src_ratio else im.height * width // im.width\r\n\r\n        resized = resize(im, src_w, src_h)\r\n        res = Image.new(\"RGB\", (width, height))\r\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\r\n\r\n    else:\r\n        ratio = width / height\r\n        src_ratio = im.width / im.height\r\n\r\n        src_w = width if ratio < src_ratio else im.width * height // im.height\r\n        src_h = height if ratio >= src_ratio else im.height * width // im.width\r\n\r\n        resized = resize(im, src_w, src_h)\r\n        res = Image.new(\"RGB\", (width, height))\r\n        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))\r\n\r\n        if ratio < src_ratio:\r\n            fill_height = height // 2 - src_h // 2\r\n            if fill_height > 0:\r\n                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))\r\n                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))\r\n        elif ratio > src_ratio:\r\n            fill_width = width // 2 - src_w // 2\r\n            if fill_width > 0:\r\n                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))\r\n                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))\r\n\r\n    return res\r\n\r\n\r\nif not shared.cmd_opts.unix_filenames_sanitization:\r\n    invalid_filename_chars = '#<>:\"/\\\\|?*\\n\\r\\t'\r\nelse:\r\n    invalid_filename_chars = '/'\r\ninvalid_filename_prefix = ' '\r\ninvalid_filename_postfix = ' .'\r\nre_nonletters = re.compile(r'[\\s' + string.punctuation + ']+')\r\nre_pattern = re.compile(r\"(.*?)(?:\\[([^\\[\\]]+)\\]|$)\")\r\nre_pattern_arg = re.compile(r\"(.*)<([^>]*)>$\")\r\nmax_filename_part_length = shared.cmd_opts.filenames_max_length\r\nNOTHING_AND_SKIP_PREVIOUS_TEXT = object()\r\n\r\n\r\ndef 
sanitize_filename_part(text, replace_spaces=True):\r\n    if text is None:\r\n        return None\r\n\r\n    if replace_spaces:\r\n        text = text.replace(' ', '_')\r\n\r\n    text = text.translate({ord(x): '_' for x in invalid_filename_chars})\r\n    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]\r\n    text = text.rstrip(invalid_filename_postfix)\r\n    return text\r\n\r\n\r\n@functools.cache\r\ndef get_scheduler_str(sampler_name, scheduler_name):\r\n    \"\"\"Returns {Scheduler} if the scheduler is applicable to the sampler\"\"\"\r\n    if scheduler_name == 'Automatic':\r\n        config = sd_samplers.find_sampler_config(sampler_name)\r\n        scheduler_name = config.options.get('scheduler', 'Automatic')\r\n    return scheduler_name.capitalize()\r\n\r\n\r\n@functools.cache\r\ndef get_sampler_scheduler_str(sampler_name, scheduler_name):\r\n    \"\"\"Returns the '{Sampler} {Scheduler}' if the scheduler is applicable to the sampler\"\"\"\r\n    return f'{sampler_name} {get_scheduler_str(sampler_name, scheduler_name)}'\r\n\r\n\r\ndef get_sampler_scheduler(p, sampler):\r\n    \"\"\"Returns '{Sampler} {Scheduler}' / '{Scheduler}' / 'NOTHING_AND_SKIP_PREVIOUS_TEXT'\"\"\"\r\n    if hasattr(p, 'scheduler') and hasattr(p, 'sampler_name'):\r\n        if sampler:\r\n            sampler_scheduler = get_sampler_scheduler_str(p.sampler_name, p.scheduler)\r\n        else:\r\n            sampler_scheduler = get_scheduler_str(p.sampler_name, p.scheduler)\r\n        return sanitize_filename_part(sampler_scheduler, replace_spaces=False)\r\n    return NOTHING_AND_SKIP_PREVIOUS_TEXT\r\n\r\n\r\nclass FilenameGenerator:\r\n    replacements = {\r\n        'basename': lambda self: self.basename or 'img',\r\n        'seed': lambda self: self.seed if self.seed is not None else '',\r\n        'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],\r\n        'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if 
self.p.batch_size == 1 else self.p.all_seeds[-1],\r\n        'steps': lambda self:  self.p and self.p.steps,\r\n        'cfg': lambda self: self.p and self.p.cfg_scale,\r\n        'width': lambda self: self.image.width,\r\n        'height': lambda self: self.image.height,\r\n        'styles': lambda self: self.p and sanitize_filename_part(\", \".join([style for style in self.p.styles if not style == \"None\"]) or \"None\", replace_spaces=False),\r\n        'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),\r\n        'sampler_scheduler': lambda self: self.p and get_sampler_scheduler(self.p, True),\r\n        'scheduler': lambda self: self.p and get_sampler_scheduler(self.p, False),\r\n        'model_hash': lambda self: getattr(self.p, \"sd_model_hash\", shared.sd_model.sd_model_hash),\r\n        'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),\r\n        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),\r\n        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]\r\n        'job_timestamp': lambda self: getattr(self.p, \"job_timestamp\", shared.state.job_timestamp),\r\n        'prompt_hash': lambda self, *args: self.string_hash(self.prompt, *args),\r\n        'negative_prompt_hash': lambda self, *args: self.string_hash(self.p.negative_prompt, *args),\r\n        'full_prompt_hash': lambda self, *args: self.string_hash(f\"{self.p.prompt} {self.p.negative_prompt}\", *args),  # a space in between to create a unique string\r\n        'prompt': lambda self: sanitize_filename_part(self.prompt),\r\n        'prompt_no_styles': lambda self: self.prompt_no_style(),\r\n        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),\r\n        'prompt_words': lambda self: self.prompt_words(),\r\n        'batch_number': 
lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1,\r\n        'batch_size': lambda self: self.p.batch_size,\r\n        'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,\r\n        'hasprompt': lambda self, *args: self.hasprompt(*args),  # accepts formats:[hasprompt<prompt1|default><prompt2>..]\r\n        'clip_skip': lambda self: opts.data[\"CLIP_stop_at_last_layers\"],\r\n        'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,\r\n        'user': lambda self: self.p.user,\r\n        'vae_filename': lambda self: self.get_vae_filename(),\r\n        'none': lambda self: '',  # Overrides the default, so you can get just the sequence number\r\n        'image_hash': lambda self, *args: self.image_hash(*args)  # accepts formats: [image_hash<length>] default full hash\r\n    }\r\n    default_time_format = '%Y%m%d%H%M%S'\r\n\r\n    def __init__(self, p, seed, prompt, image, zip=False, basename=\"\"):\r\n        self.p = p\r\n        self.seed = seed\r\n        self.prompt = prompt\r\n        self.image = image\r\n        self.zip = zip\r\n        self.basename = basename\r\n\r\n    def get_vae_filename(self):\r\n        \"\"\"Get the name of the VAE file.\"\"\"\r\n\r\n        import modules.sd_vae as sd_vae\r\n\r\n        if sd_vae.loaded_vae_file is None:\r\n            return \"NoneType\"\r\n\r\n        file_name = os.path.basename(sd_vae.loaded_vae_file)\r\n        split_file_name = file_name.split('.')\r\n        if len(split_file_name) > 1 and split_file_name[0] == '':\r\n            return split_file_name[1]  # if the first character of the filename is \".\" then [1] is obtained.\r\n        else:\r\n            return split_file_name[0]\r\n\r\n\r\n    def hasprompt(self, *args):\r\n     
   lower = self.prompt.lower()\r\n        if self.p is None or self.prompt is None:\r\n            return None\r\n        outres = \"\"\r\n        for arg in args:\r\n            if arg != \"\":\r\n                division = arg.split(\"|\")\r\n                expected = division[0].lower()\r\n                default = division[1] if len(division) > 1 else \"\"\r\n                if lower.find(expected) >= 0:\r\n                    outres = f'{outres}{expected}'\r\n                else:\r\n                    outres = outres if default == \"\" else f'{outres}{default}'\r\n        return sanitize_filename_part(outres)\r\n\r\n    def prompt_no_style(self):\r\n        if self.p is None or self.prompt is None:\r\n            return None\r\n\r\n        prompt_no_style = self.prompt\r\n        for style in shared.prompt_styles.get_style_prompts(self.p.styles):\r\n            if style:\r\n                for part in style.split(\"{prompt}\"):\r\n                    prompt_no_style = prompt_no_style.replace(part, \"\").replace(\", ,\", \",\").strip().strip(',')\r\n\r\n                prompt_no_style = prompt_no_style.replace(style, \"\").strip().strip(',').strip()\r\n\r\n        return sanitize_filename_part(prompt_no_style, replace_spaces=False)\r\n\r\n    def prompt_words(self):\r\n        words = [x for x in re_nonletters.split(self.prompt or \"\") if x]\r\n        if len(words) == 0:\r\n            words = [\"empty\"]\r\n        return sanitize_filename_part(\" \".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)\r\n\r\n    def datetime(self, *args):\r\n        time_datetime = datetime.datetime.now()\r\n\r\n        time_format = args[0] if (args and args[0] != \"\") else self.default_time_format\r\n        try:\r\n            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None\r\n        except pytz.exceptions.UnknownTimeZoneError:\r\n            time_zone = None\r\n\r\n        time_zone_time = time_datetime.astimezone(time_zone)\r\n     
   try:\r\n            formatted_time = time_zone_time.strftime(time_format)\r\n        except (ValueError, TypeError):\r\n            formatted_time = time_zone_time.strftime(self.default_time_format)\r\n\r\n        return sanitize_filename_part(formatted_time, replace_spaces=False)\r\n\r\n    def image_hash(self, *args):\r\n        length = int(args[0]) if (args and args[0] != \"\") else None\r\n        return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length]\r\n\r\n    def string_hash(self, text, *args):\r\n        length = int(args[0]) if (args and args[0] != \"\") else 8\r\n        return hashlib.sha256(text.encode()).hexdigest()[0:length]\r\n\r\n    def apply(self, x):\r\n        res = ''\r\n\r\n        for m in re_pattern.finditer(x):\r\n            text, pattern = m.groups()\r\n\r\n            if pattern is None:\r\n                res += text\r\n                continue\r\n\r\n            pattern_args = []\r\n            while True:\r\n                m = re_pattern_arg.match(pattern)\r\n                if m is None:\r\n                    break\r\n\r\n                pattern, arg = m.groups()\r\n                pattern_args.insert(0, arg)\r\n\r\n            fun = self.replacements.get(pattern.lower())\r\n            if fun is not None:\r\n                try:\r\n                    replacement = fun(self, *pattern_args)\r\n                except Exception:\r\n                    replacement = None\r\n                    errors.report(f\"Error adding [{pattern}] to filename\", exc_info=True)\r\n\r\n                if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:\r\n                    continue\r\n                elif replacement is not None:\r\n                    res += text + str(replacement)\r\n                    continue\r\n\r\n            res += f'{text}[{pattern}]'\r\n\r\n        return res\r\n\r\n\r\ndef get_next_sequence_number(path, basename):\r\n    \"\"\"\r\n    Determines and returns the next sequence number to use when saving an 
image in the specified directory.\r\n\r\n    The sequence starts at 0.\r\n    \"\"\"\r\n    result = -1\r\n    if basename != '':\r\n        basename = f\"{basename}-\"\r\n\r\n    prefix_length = len(basename)\r\n    for p in os.listdir(path):\r\n        if p.startswith(basename):\r\n            parts = os.path.splitext(p[prefix_length:])[0].split('-')  # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)\r\n            try:\r\n                result = max(int(parts[0]), result)\r\n            except ValueError:\r\n                pass\r\n\r\n    return result + 1\r\n\r\n\r\ndef save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None, pnginfo_section_name='parameters'):\r\n    \"\"\"\r\n    Saves image to filename, including geninfo as text information for generation info.\r\n    For PNG images, geninfo is added to existing pnginfo dictionary using the pnginfo_section_name argument as key.\r\n    For JPG images, there's no dictionary and geninfo just replaces the EXIF description.\r\n    \"\"\"\r\n\r\n    if extension is None:\r\n        extension = os.path.splitext(filename)[1]\r\n\r\n    image_format = Image.registered_extensions()[extension]\r\n\r\n    if extension.lower() == '.png':\r\n        existing_pnginfo = existing_pnginfo or {}\r\n        if opts.enable_pnginfo:\r\n            existing_pnginfo[pnginfo_section_name] = geninfo\r\n\r\n        if opts.enable_pnginfo:\r\n            pnginfo_data = PngImagePlugin.PngInfo()\r\n            for k, v in (existing_pnginfo or {}).items():\r\n                pnginfo_data.add_text(k, str(v))\r\n        else:\r\n            pnginfo_data = None\r\n\r\n        image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)\r\n\r\n    elif extension.lower() in (\".jpg\", \".jpeg\", \".webp\"):\r\n        if image.mode == 'RGBA':\r\n            image = image.convert(\"RGB\")\r\n        elif 
image.mode == 'I;16':\r\n            image = image.point(lambda p: p * 0.0038910505836576).convert(\"RGB\" if extension.lower() == \".webp\" else \"L\")\r\n\r\n        image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)\r\n\r\n        if opts.enable_pnginfo and geninfo is not None:\r\n            exif_bytes = piexif.dump({\r\n                \"Exif\": {\r\n                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or \"\", encoding=\"unicode\")\r\n                },\r\n            })\r\n\r\n            piexif.insert(exif_bytes, filename)\r\n    elif extension.lower() == '.avif':\r\n        if opts.enable_pnginfo and geninfo is not None:\r\n            exif_bytes = piexif.dump({\r\n                \"Exif\": {\r\n                    piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or \"\", encoding=\"unicode\")\r\n                },\r\n            })\r\n        else:\r\n            exif_bytes = None\r\n\r\n        image.save(filename,format=image_format, quality=opts.jpeg_quality, exif=exif_bytes)\r\n    elif extension.lower() == \".gif\":\r\n        image.save(filename, format=image_format, comment=geninfo)\r\n    else:\r\n        image.save(filename, format=image_format, quality=opts.jpeg_quality)\r\n\r\n\r\ndef save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=\"\", save_to_dirs=None):\r\n    \"\"\"Save an image.\r\n\r\n    Args:\r\n        image (`PIL.Image`):\r\n            The image to be saved.\r\n        path (`str`):\r\n            The directory to save the image. 
Note, the option `save_to_dirs` will make the image to be saved into a sub directory.\r\n        basename (`str`):\r\n            The base filename which will be applied to `filename pattern`.\r\n        seed, prompt, short_filename,\r\n        extension (`str`):\r\n            Image file extension, default is `png`.\r\n        pngsectionname (`str`):\r\n            Specify the name of the section which `info` will be saved in.\r\n        info (`str` or `PngImagePlugin.iTXt`):\r\n            PNG info chunks.\r\n        existing_info (`dict`):\r\n            Additional PNG info. `existing_info == {pngsectionname: info, ...}`\r\n        no_prompt:\r\n            TODO I don't know its meaning.\r\n        p (`StableDiffusionProcessing`)\r\n        forced_filename (`str`):\r\n            If specified, `basename` and filename pattern will be ignored.\r\n        save_to_dirs (bool):\r\n            If true, the image will be saved into a subdirectory of `path`.\r\n\r\n    Returns: (fullfn, txt_fullfn)\r\n        fullfn (`str`):\r\n            The full path of the saved imaged.\r\n        txt_fullfn (`str` or None):\r\n            If a text file is saved for this image, this will be its full path. Otherwise None.\r\n    \"\"\"\r\n    namegen = FilenameGenerator(p, seed, prompt, image, basename=basename)\r\n\r\n    # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. 
switch to PNG which has a much higher limit\r\n    if (image.height > 65535 or image.width > 65535) and extension.lower() in (\"jpg\", \"jpeg\") or (image.height > 16383 or image.width > 16383) and extension.lower() == \"webp\":\r\n        print('Image dimensions too large; saving as PNG')\r\n        extension = \"png\"\r\n\r\n    if save_to_dirs is None:\r\n        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)\r\n\r\n    if save_to_dirs:\r\n        dirname = namegen.apply(opts.directories_filename_pattern or \"[prompt_words]\").lstrip(' ').rstrip('\\\\ /')\r\n        path = os.path.join(path, dirname)\r\n\r\n    os.makedirs(path, exist_ok=True)\r\n\r\n    if forced_filename is None:\r\n        if short_filename or seed is None:\r\n            file_decoration = \"\"\r\n        elif opts.save_to_dirs:\r\n            file_decoration = opts.samples_filename_pattern or \"[seed]\"\r\n        else:\r\n            file_decoration = opts.samples_filename_pattern or \"[seed]-[prompt_spaces]\"\r\n\r\n        file_decoration = namegen.apply(file_decoration) + suffix\r\n\r\n        add_number = opts.save_images_add_number or file_decoration == ''\r\n\r\n        if file_decoration != \"\" and add_number:\r\n            file_decoration = f\"-{file_decoration}\"\r\n\r\n        if add_number:\r\n            basecount = get_next_sequence_number(path, basename)\r\n            fullfn = None\r\n            for i in range(500):\r\n                fn = f\"{basecount + i:05}\" if basename == '' else f\"{basename}-{basecount + i:04}\"\r\n                fullfn = os.path.join(path, f\"{fn}{file_decoration}.{extension}\")\r\n                if not os.path.exists(fullfn):\r\n                    break\r\n        else:\r\n            fullfn = os.path.join(path, f\"{file_decoration}.{extension}\")\r\n    else:\r\n        fullfn = os.path.join(path, f\"{forced_filename}.{extension}\")\r\n\r\n    pnginfo = existing_info or {}\r\n    if info 
is not None:\r\n        pnginfo[pnginfo_section_name] = info\r\n\r\n    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)\r\n    script_callbacks.before_image_saved_callback(params)\r\n\r\n    image = params.image\r\n    fullfn = params.filename\r\n    info = params.pnginfo.get(pnginfo_section_name, None)\r\n\r\n    def _atomically_save_image(image_to_save, filename_without_extension, extension):\r\n        \"\"\"\r\n        save image with .tmp extension to avoid race condition when another process detects new image in the directory\r\n        \"\"\"\r\n        temp_file_path = f\"{filename_without_extension}.tmp\"\r\n\r\n        save_image_with_geninfo(image_to_save, info, temp_file_path, extension, existing_pnginfo=params.pnginfo, pnginfo_section_name=pnginfo_section_name)\r\n\r\n        filename = filename_without_extension + extension\r\n        if shared.opts.save_images_replace_action != \"Replace\":\r\n            n = 0\r\n            while os.path.exists(filename):\r\n                n += 1\r\n                filename = f\"{filename_without_extension}-{n}{extension}\"\r\n        os.replace(temp_file_path, filename)\r\n\r\n    fullfn_without_extension, extension = os.path.splitext(params.filename)\r\n    if hasattr(os, 'statvfs'):\r\n        max_name_len = os.statvfs(path).f_namemax\r\n        fullfn_without_extension = fullfn_without_extension[:max_name_len - max(4, len(extension))]\r\n        params.filename = fullfn_without_extension + extension\r\n        fullfn = params.filename\r\n    _atomically_save_image(image, fullfn_without_extension, extension)\r\n\r\n    image.already_saved_as = fullfn\r\n\r\n    oversize = image.width > opts.target_side_length or image.height > opts.target_side_length\r\n    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > opts.img_downscale_threshold * 1024 * 1024):\r\n        ratio = image.width / image.height\r\n        resize_to = None\r\n        if oversize and ratio > 1:\r\n          
  resize_to = round(opts.target_side_length), round(image.height * opts.target_side_length / image.width)\r\n        elif oversize:\r\n            resize_to = round(image.width * opts.target_side_length / image.height), round(opts.target_side_length)\r\n\r\n        if resize_to is not None:\r\n            try:\r\n                # Resizing image with LANCZOS could throw an exception if e.g. image mode is I;16\r\n                image = image.resize(resize_to, LANCZOS)\r\n            except Exception:\r\n                image = image.resize(resize_to)\r\n        try:\r\n            _atomically_save_image(image, fullfn_without_extension, \".jpg\")\r\n        except Exception as e:\r\n            errors.display(e, \"saving image as downscaled JPG\")\r\n\r\n    if opts.save_txt and info is not None:\r\n        txt_fullfn = f\"{fullfn_without_extension}.txt\"\r\n        with open(txt_fullfn, \"w\", encoding=\"utf8\") as file:\r\n            file.write(f\"{info}\\n\")\r\n    else:\r\n        txt_fullfn = None\r\n\r\n    script_callbacks.image_saved_callback(params)\r\n\r\n    return fullfn, txt_fullfn\r\n\r\n\r\nIGNORED_INFO_KEYS = {\r\n    'jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif',\r\n    'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression',\r\n    'icc_profile', 'chromaticity', 'photoshop',\r\n}\r\n\r\n\r\ndef read_info_from_image(image: Image.Image) -> tuple[str | None, dict]:\r\n    items = (image.info or {}).copy()\r\n\r\n    geninfo = items.pop('parameters', None)\r\n\r\n    if \"exif\" in items:\r\n        exif_data = items[\"exif\"]\r\n        try:\r\n            exif = piexif.load(exif_data)\r\n        except OSError:\r\n            # memory / exif was not valid so piexif tried to read from a file\r\n            exif = None\r\n        exif_comment = (exif or {}).get(\"Exif\", {}).get(piexif.ExifIFD.UserComment, b'')\r\n        try:\r\n            exif_comment = piexif.helper.UserComment.load(exif_comment)\r\n       
 except ValueError:\r\n            exif_comment = exif_comment.decode('utf8', errors=\"ignore\")\r\n\r\n        if exif_comment:\r\n            geninfo = exif_comment\r\n    elif \"comment\" in items: # for gif\r\n        if isinstance(items[\"comment\"], bytes):\r\n            geninfo = items[\"comment\"].decode('utf8', errors=\"ignore\")\r\n        else:\r\n            geninfo = items[\"comment\"]\r\n\r\n    for field in IGNORED_INFO_KEYS:\r\n        items.pop(field, None)\r\n\r\n    if items.get(\"Software\", None) == \"NovelAI\":\r\n        try:\r\n            json_info = json.loads(items[\"Comment\"])\r\n            sampler = sd_samplers.samplers_map.get(json_info[\"sampler\"], \"Euler a\")\r\n\r\n            geninfo = f\"\"\"{items[\"Description\"]}\r\nNegative prompt: {json_info[\"uc\"]}\r\nSteps: {json_info[\"steps\"]}, Sampler: {sampler}, CFG scale: {json_info[\"scale\"]}, Seed: {json_info[\"seed\"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337\"\"\"\r\n        except Exception:\r\n            errors.report(\"Error parsing NovelAI image generation parameters\", exc_info=True)\r\n\r\n    return geninfo, items\r\n\r\n\r\ndef image_data(data):\r\n    import gradio as gr\r\n\r\n    try:\r\n        image = read(io.BytesIO(data))\r\n        textinfo, _ = read_info_from_image(image)\r\n        return textinfo, None\r\n    except Exception:\r\n        pass\r\n\r\n    try:\r\n        text = data.decode('utf8')\r\n        assert len(text) < 10000\r\n        return text, None\r\n\r\n    except Exception:\r\n        pass\r\n\r\n    return gr.update(), None\r\n\r\n\r\ndef flatten(img, bgcolor):\r\n    \"\"\"replaces transparency with bgcolor (example: \"#ffffff\"), returning an RGB mode image with no transparency\"\"\"\r\n\r\n    if img.mode == \"RGBA\":\r\n        background = Image.new('RGBA', img.size, bgcolor)\r\n        background.paste(img, mask=img)\r\n        img = background\r\n\r\n    return img.convert('RGB')\r\n\r\n\r\ndef read(fp, 
**kwargs):\r\n    image = Image.open(fp, **kwargs)\r\n    image = fix_image(image)\r\n\r\n    return image\r\n\r\n\r\ndef fix_image(image: Image.Image):\r\n    if image is None:\r\n        return None\r\n\r\n    try:\r\n        image = ImageOps.exif_transpose(image)\r\n        image = fix_png_transparency(image)\r\n    except Exception:\r\n        pass\r\n\r\n    return image\r\n\r\n\r\ndef fix_png_transparency(image: Image.Image):\r\n    if image.mode not in (\"RGB\", \"P\") or not isinstance(image.info.get(\"transparency\"), bytes):\r\n        return image\r\n\r\n    image = image.convert(\"RGBA\")\r\n    return image\r\n"
  },
  {
    "path": "modules/img2img.py",
    "content": "import os\r\nfrom contextlib import closing\r\nfrom pathlib import Path\r\n\r\nimport numpy as np\r\nfrom PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError\r\nimport gradio as gr\r\n\r\nfrom modules import images\r\nfrom modules.infotext_utils import create_override_settings_dict, parse_generation_parameters\r\nfrom modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images\r\nfrom modules.shared import opts, state\r\nfrom modules.sd_models import get_closet_checkpoint_match\r\nimport modules.shared as shared\r\nimport modules.processing as processing\r\nfrom modules.ui import plaintext_to_html\r\nimport modules.scripts\r\n\r\n\r\ndef process_batch(p, input, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):\r\n    output_dir = output_dir.strip()\r\n    processing.fix_seed(p)\r\n\r\n    if isinstance(input, str):\r\n        batch_images = list(shared.walk_files(input, allowed_extensions=(\".png\", \".jpg\", \".jpeg\", \".webp\", \".tif\", \".tiff\")))\r\n    else:\r\n        batch_images = [os.path.abspath(x.name) for x in input]\r\n\r\n    is_inpaint_batch = False\r\n    if inpaint_mask_dir:\r\n        inpaint_masks = shared.listfiles(inpaint_mask_dir)\r\n        is_inpaint_batch = bool(inpaint_masks)\r\n\r\n        if is_inpaint_batch:\r\n            print(f\"\\nInpaint batch is enabled. 
{len(inpaint_masks)} masks found.\")\r\n\r\n    print(f\"Will process {len(batch_images)} images, creating {p.n_iter * p.batch_size} new images for each.\")\r\n\r\n    state.job_count = len(batch_images) * p.n_iter\r\n\r\n    # extract \"default\" params to use in case getting png info fails\r\n    prompt = p.prompt\r\n    negative_prompt = p.negative_prompt\r\n    seed = p.seed\r\n    cfg_scale = p.cfg_scale\r\n    sampler_name = p.sampler_name\r\n    steps = p.steps\r\n    override_settings = p.override_settings\r\n    sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get(\"sd_model_checkpoint\", None))\r\n    batch_results = None\r\n    discard_further_results = False\r\n    for i, image in enumerate(batch_images):\r\n        state.job = f\"{i+1} out of {len(batch_images)}\"\r\n        if state.skipped:\r\n            state.skipped = False\r\n\r\n        if state.interrupted or state.stopping_generation:\r\n            break\r\n\r\n        try:\r\n            img = images.read(image)\r\n        except UnidentifiedImageError as e:\r\n            print(e)\r\n            continue\r\n        # Use the EXIF orientation of photos taken by smartphones.\r\n        img = ImageOps.exif_transpose(img)\r\n\r\n        if to_scale:\r\n            p.width = int(img.width * scale_by)\r\n            p.height = int(img.height * scale_by)\r\n\r\n        p.init_images = [img] * p.batch_size\r\n\r\n        image_path = Path(image)\r\n        if is_inpaint_batch:\r\n            # try to find corresponding mask for an image using simple filename matching\r\n            if len(inpaint_masks) == 1:\r\n                mask_image_path = inpaint_masks[0]\r\n            else:\r\n                # try to find corresponding mask for an image using simple filename matching\r\n                mask_image_dir = Path(inpaint_mask_dir)\r\n                masks_found = list(mask_image_dir.glob(f\"{image_path.stem}.*\"))\r\n\r\n                if len(masks_found) == 0:\r\n  
                  print(f\"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.\")\r\n                    continue\r\n\r\n                # it should contain only 1 matching mask\r\n                # otherwise user has many masks with the same name but different extensions\r\n                mask_image_path = masks_found[0]\r\n\r\n            mask_image = images.read(mask_image_path)\r\n            p.image_mask = mask_image\r\n\r\n        if use_png_info:\r\n            try:\r\n                info_img = img\r\n                if png_info_dir:\r\n                    info_img_path = os.path.join(png_info_dir, os.path.basename(image))\r\n                    info_img = images.read(info_img_path)\r\n                geninfo, _ = images.read_info_from_image(info_img)\r\n                parsed_parameters = parse_generation_parameters(geninfo)\r\n                parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}\r\n            except Exception:\r\n                parsed_parameters = {}\r\n\r\n            p.prompt = prompt + (\" \" + parsed_parameters[\"Prompt\"] if \"Prompt\" in parsed_parameters else \"\")\r\n            p.negative_prompt = negative_prompt + (\" \" + parsed_parameters[\"Negative prompt\"] if \"Negative prompt\" in parsed_parameters else \"\")\r\n            p.seed = int(parsed_parameters.get(\"Seed\", seed))\r\n            p.cfg_scale = float(parsed_parameters.get(\"CFG scale\", cfg_scale))\r\n            p.sampler_name = parsed_parameters.get(\"Sampler\", sampler_name)\r\n            p.steps = int(parsed_parameters.get(\"Steps\", steps))\r\n\r\n            model_info = get_closet_checkpoint_match(parsed_parameters.get(\"Model hash\", None))\r\n            if model_info is not None:\r\n                p.override_settings['sd_model_checkpoint'] = model_info.name\r\n            elif sd_model_checkpoint_override:\r\n                p.override_settings['sd_model_checkpoint'] = 
sd_model_checkpoint_override\r\n            else:\r\n                p.override_settings.pop(\"sd_model_checkpoint\", None)\r\n\r\n        if output_dir:\r\n            p.outpath_samples = output_dir\r\n            p.override_settings['save_to_dirs'] = False\r\n            p.override_settings['save_images_replace_action'] = \"Add number suffix\"\r\n            if p.n_iter > 1 or p.batch_size > 1:\r\n                p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'\r\n            else:\r\n                p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'\r\n\r\n        proc = modules.scripts.scripts_img2img.run(p, *args)\r\n\r\n        if proc is None:\r\n            p.override_settings.pop('save_images_replace_action', None)\r\n            proc = process_images(p)\r\n\r\n        if not discard_further_results and proc:\r\n            if batch_results:\r\n                batch_results.images.extend(proc.images)\r\n                batch_results.infotexts.extend(proc.infotexts)\r\n            else:\r\n                batch_results = proc\r\n\r\n            if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):\r\n                discard_further_results = True\r\n                batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]\r\n                batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]\r\n\r\n    return batch_results\r\n\r\n\r\ndef img2img(id_task: str, request: gr.Request, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, selected_scale_tab: int, height: int, width: int, scale_by: float, 
resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, img2img_batch_source_type: str, img2img_batch_upload: list, *args):\r\n    override_settings = create_override_settings_dict(override_settings_texts)\r\n\r\n    is_batch = mode == 5\r\n\r\n    if mode == 0:  # img2img\r\n        image = init_img\r\n        mask = None\r\n    elif mode == 1:  # img2img sketch\r\n        image = sketch\r\n        mask = None\r\n    elif mode == 2:  # inpaint\r\n        image, mask = init_img_with_mask[\"image\"], init_img_with_mask[\"mask\"]\r\n        mask = processing.create_binary_mask(mask)\r\n    elif mode == 3:  # inpaint sketch\r\n        image = inpaint_color_sketch\r\n        orig = inpaint_color_sketch_orig or inpaint_color_sketch\r\n        pred = np.any(np.array(image) != np.array(orig), axis=-1)\r\n        mask = Image.fromarray(pred.astype(np.uint8) * 255, \"L\")\r\n        mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)\r\n        blur = ImageFilter.GaussianBlur(mask_blur)\r\n        image = Image.composite(image.filter(blur), orig, mask.filter(blur))\r\n    elif mode == 4:  # inpaint upload mask\r\n        image = init_img_inpaint\r\n        mask = init_mask_inpaint\r\n    else:\r\n        image = None\r\n        mask = None\r\n\r\n    image = images.fix_image(image)\r\n    mask = images.fix_image(mask)\r\n\r\n    if selected_scale_tab == 1 and not is_batch:\r\n        assert image, \"Can't scale by because no image is selected\"\r\n\r\n        width = int(image.width * scale_by)\r\n        height = int(image.height * scale_by)\r\n\r\n    assert 0. 
<= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'\r\n\r\n    p = StableDiffusionProcessingImg2Img(\r\n        sd_model=shared.sd_model,\r\n        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,\r\n        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,\r\n        prompt=prompt,\r\n        negative_prompt=negative_prompt,\r\n        styles=prompt_styles,\r\n        batch_size=batch_size,\r\n        n_iter=n_iter,\r\n        cfg_scale=cfg_scale,\r\n        width=width,\r\n        height=height,\r\n        init_images=[image],\r\n        mask=mask,\r\n        mask_blur=mask_blur,\r\n        inpainting_fill=inpainting_fill,\r\n        resize_mode=resize_mode,\r\n        denoising_strength=denoising_strength,\r\n        image_cfg_scale=image_cfg_scale,\r\n        inpaint_full_res=inpaint_full_res,\r\n        inpaint_full_res_padding=inpaint_full_res_padding,\r\n        inpainting_mask_invert=inpainting_mask_invert,\r\n        override_settings=override_settings,\r\n    )\r\n\r\n    p.scripts = modules.scripts.scripts_img2img\r\n    p.script_args = args\r\n\r\n    p.user = request.username\r\n\r\n    if shared.opts.enable_console_prompts:\r\n        print(f\"\\nimg2img: {prompt}\", file=shared.progress_print_out)\r\n\r\n    with closing(p):\r\n        if is_batch:\r\n            if img2img_batch_source_type == \"upload\":\r\n                assert isinstance(img2img_batch_upload, list) and img2img_batch_upload\r\n                output_dir = \"\"\r\n                inpaint_mask_dir = \"\"\r\n                png_info_dir = img2img_batch_png_info_dir if not shared.cmd_opts.hide_ui_dir_config else \"\"\r\n                processed = process_batch(p, img2img_batch_upload, output_dir, inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=png_info_dir)\r\n            else: # \"from dir\"\r\n          
      assert not shared.cmd_opts.hide_ui_dir_config, \"Launched with --hide-ui-dir-config, batch img2img disabled\"\r\n                processed = process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by, use_png_info=img2img_batch_use_png_info, png_info_props=img2img_batch_png_info_props, png_info_dir=img2img_batch_png_info_dir)\r\n\r\n            if processed is None:\r\n                processed = Processed(p, [], p.seed, \"\")\r\n        else:\r\n            processed = modules.scripts.scripts_img2img.run(p, *args)\r\n            if processed is None:\r\n                processed = process_images(p)\r\n\r\n    shared.total_tqdm.clear()\r\n\r\n    generation_info_js = processed.js()\r\n    if opts.samples_log_stdout:\r\n        print(generation_info_js)\r\n\r\n    if opts.do_not_show_images:\r\n        processed.images = []\r\n\r\n    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname=\"comments\")\r\n"
  },
  {
    "path": "modules/import_hook.py",
    "content": "import sys\n\n# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it\nif \"--xformers\" not in \"\".join(sys.argv):\n    sys.modules[\"xformers\"] = None\n\n# Hack to fix a changed import in torchvision 0.17+, which otherwise breaks\n# basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985\ntry:\n    import torchvision.transforms.functional_tensor  # noqa: F401\nexcept ImportError:\n    try:\n        import torchvision.transforms.functional as functional\n        sys.modules[\"torchvision.transforms.functional_tensor\"] = functional\n    except ImportError:\n        pass  # shrug...\n"
  },
  {
    "path": "modules/infotext_utils.py",
    "content": "from __future__ import annotations\r\nimport base64\r\nimport io\r\nimport json\r\nimport os\r\nimport re\r\nimport sys\r\n\r\nimport gradio as gr\r\nfrom modules.paths import data_path\r\nfrom modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser, errors\r\nfrom PIL import Image\r\n\r\nsys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name\r\n\r\nre_param_code = r'\\s*(\\w[\\w \\-/]+):\\s*(\"(?:\\\\.|[^\\\\\"])+\"|[^,]*)(?:,|$)'\r\nre_param = re.compile(re_param_code)\r\nre_imagesize = re.compile(r\"^(\\d+)x(\\d+)$\")\r\nre_hypernet_hash = re.compile(\"\\(([0-9a-f]+)\\)$\")\r\ntype_of_gr_update = type(gr.update())\r\n\r\n\r\nclass ParamBinding:\r\n    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):\r\n        self.paste_button = paste_button\r\n        self.tabname = tabname\r\n        self.source_text_component = source_text_component\r\n        self.source_image_component = source_image_component\r\n        self.source_tabname = source_tabname\r\n        self.override_settings_component = override_settings_component\r\n        self.paste_field_names = paste_field_names or []\r\n\r\n\r\nclass PasteField(tuple):\r\n    def __new__(cls, component, target, *, api=None):\r\n        return super().__new__(cls, (component, target))\r\n\r\n    def __init__(self, component, target, *, api=None):\r\n        super().__init__()\r\n\r\n        self.api = api\r\n        self.component = component\r\n        self.label = target if isinstance(target, str) else None\r\n        self.function = target if callable(target) else None\r\n\r\n\r\npaste_fields: dict[str, dict] = {}\r\nregistered_param_bindings: list[ParamBinding] = []\r\n\r\n\r\ndef reset():\r\n    paste_fields.clear()\r\n    registered_param_bindings.clear()\r\n\r\n\r\ndef 
quote(text):\r\n    if ',' not in str(text) and '\\n' not in str(text) and ':' not in str(text):\r\n        return text\r\n\r\n    return json.dumps(text, ensure_ascii=False)\r\n\r\n\r\ndef unquote(text):\r\n    if len(text) == 0 or text[0] != '\"' or text[-1] != '\"':\r\n        return text\r\n\r\n    try:\r\n        return json.loads(text)\r\n    except Exception:\r\n        return text\r\n\r\n\r\ndef image_from_url_text(filedata):\r\n    if filedata is None:\r\n        return None\r\n\r\n    if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get(\"is_file\", False):\r\n        filedata = filedata[0]\r\n\r\n    if type(filedata) == dict and filedata.get(\"is_file\", False):\r\n        filename = filedata[\"name\"]\r\n        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)\r\n        assert is_in_right_dir, 'trying to open image file outside of allowed directories'\r\n\r\n        filename = filename.rsplit('?', 1)[0]\r\n        return images.read(filename)\r\n\r\n    if type(filedata) == list:\r\n        if len(filedata) == 0:\r\n            return None\r\n\r\n        filedata = filedata[0]\r\n\r\n    if filedata.startswith(\"data:image/png;base64,\"):\r\n        filedata = filedata[len(\"data:image/png;base64,\"):]\r\n\r\n    filedata = base64.decodebytes(filedata.encode('utf-8'))\r\n    image = images.read(io.BytesIO(filedata))\r\n    return image\r\n\r\n\r\ndef add_paste_fields(tabname, init_img, fields, override_settings_component=None):\r\n\r\n    if fields:\r\n        for i in range(len(fields)):\r\n            if not isinstance(fields[i], PasteField):\r\n                fields[i] = PasteField(*fields[i])\r\n\r\n    paste_fields[tabname] = {\"init_img\": init_img, \"fields\": fields, \"override_settings_component\": override_settings_component}\r\n\r\n    # backwards compatibility for existing extensions\r\n    import modules.ui\r\n    if tabname == 'txt2img':\r\n        modules.ui.txt2img_paste_fields = 
fields\r\n    elif tabname == 'img2img':\r\n        modules.ui.img2img_paste_fields = fields\r\n\r\n\r\ndef create_buttons(tabs_list):\r\n    buttons = {}\r\n    for tab in tabs_list:\r\n        buttons[tab] = gr.Button(f\"Send to {tab}\", elem_id=f\"{tab}_tab\")\r\n    return buttons\r\n\r\n\r\ndef bind_buttons(buttons, send_image, send_generate_info):\r\n    \"\"\"old function for backwards compatibility; do not use this, use register_paste_params_button\"\"\"\r\n    for tabname, button in buttons.items():\r\n        source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None\r\n        source_tabname = send_generate_info if isinstance(send_generate_info, str) else None\r\n\r\n        register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))\r\n\r\n\r\ndef register_paste_params_button(binding: ParamBinding):\r\n    registered_param_bindings.append(binding)\r\n\r\n\r\ndef connect_paste_params_buttons():\r\n    for binding in registered_param_bindings:\r\n        destination_image_component = paste_fields[binding.tabname][\"init_img\"]\r\n        fields = paste_fields[binding.tabname][\"fields\"]\r\n        override_settings_component = binding.override_settings_component or paste_fields[binding.tabname][\"override_settings_component\"]\r\n\r\n        destination_width_component = next(iter([field for field, name in fields if name == \"Size-1\"] if fields else []), None)\r\n        destination_height_component = next(iter([field for field, name in fields if name == \"Size-2\"] if fields else []), None)\r\n\r\n        if binding.source_image_component and destination_image_component:\r\n            need_send_dementions = destination_width_component and binding.tabname != 'inpaint'\r\n            if isinstance(binding.source_image_component, gr.Gallery):\r\n                func = 
send_image_and_dimensions if need_send_dementions else image_from_url_text\r\n                jsfunc = \"extract_image_from_gallery\"\r\n            else:\r\n                func = send_image_and_dimensions if need_send_dementions else lambda x: x\r\n                jsfunc = None\r\n\r\n            binding.paste_button.click(\r\n                fn=func,\r\n                _js=jsfunc,\r\n                inputs=[binding.source_image_component],\r\n                outputs=[destination_image_component, destination_width_component, destination_height_component] if need_send_dementions else [destination_image_component],\r\n                show_progress=False,\r\n            )\r\n\r\n        if binding.source_text_component is not None and fields is not None:\r\n            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)\r\n\r\n        if binding.source_tabname is not None and fields is not None:\r\n            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + ([\"Seed\"] if shared.opts.send_seed else []) + binding.paste_field_names\r\n            binding.paste_button.click(\r\n                fn=lambda *x: x,\r\n                inputs=[field for field, name in paste_fields[binding.source_tabname][\"fields\"] if name in paste_field_names],\r\n                outputs=[field for field, name in fields if name in paste_field_names],\r\n                show_progress=False,\r\n            )\r\n\r\n        binding.paste_button.click(\r\n            fn=None,\r\n            _js=f\"switch_to_{binding.tabname}\",\r\n            inputs=None,\r\n            outputs=None,\r\n            show_progress=False,\r\n        )\r\n\r\n\r\ndef send_image_and_dimensions(x):\r\n    if isinstance(x, Image.Image):\r\n        img = x\r\n    else:\r\n        img = image_from_url_text(x)\r\n\r\n    if shared.opts.send_size and isinstance(img, Image.Image):\r\n        w = img.width\r\n        h = 
img.height\r\n    else:\r\n        w = gr.update()\r\n        h = gr.update()\r\n\r\n    return img, w, h\r\n\r\n\r\ndef restore_old_hires_fix_params(res):\r\n    \"\"\"for infotexts that specify old First pass size parameter, convert it into\r\n    width, height, and hr scale\"\"\"\r\n\r\n    firstpass_width = res.get('First pass size-1', None)\r\n    firstpass_height = res.get('First pass size-2', None)\r\n\r\n    if shared.opts.use_old_hires_fix_width_height:\r\n        hires_width = int(res.get(\"Hires resize-1\", 0))\r\n        hires_height = int(res.get(\"Hires resize-2\", 0))\r\n\r\n        if hires_width and hires_height:\r\n            res['Size-1'] = hires_width\r\n            res['Size-2'] = hires_height\r\n            return\r\n\r\n    if firstpass_width is None or firstpass_height is None:\r\n        return\r\n\r\n    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)\r\n    width = int(res.get(\"Size-1\", 512))\r\n    height = int(res.get(\"Size-2\", 512))\r\n\r\n    if firstpass_width == 0 or firstpass_height == 0:\r\n        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)\r\n\r\n    res['Size-1'] = firstpass_width\r\n    res['Size-2'] = firstpass_height\r\n    res['Hires resize-1'] = width\r\n    res['Hires resize-2'] = height\r\n\r\n\r\ndef parse_generation_parameters(x: str, skip_fields: list[str] | None = None):\r\n    \"\"\"parses generation parameters string, the one you see in text field under the picture in UI:\r\n```\r\ngirl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate\r\nNegative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing\r\nSteps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 
512x512, Model hash: 45dee52b\r\n```\r\n\r\n    returns a dict with field values\r\n    \"\"\"\r\n    if skip_fields is None:\r\n        skip_fields = shared.opts.infotext_skip_pasting\r\n\r\n    res = {}\r\n\r\n    prompt = \"\"\r\n    negative_prompt = \"\"\r\n\r\n    done_with_prompt = False\r\n\r\n    *lines, lastline = x.strip().split(\"\\n\")\r\n    if len(re_param.findall(lastline)) < 3:\r\n        lines.append(lastline)\r\n        lastline = ''\r\n\r\n    for line in lines:\r\n        line = line.strip()\r\n        if line.startswith(\"Negative prompt:\"):\r\n            done_with_prompt = True\r\n            line = line[16:].strip()\r\n        if done_with_prompt:\r\n            negative_prompt += (\"\" if negative_prompt == \"\" else \"\\n\") + line\r\n        else:\r\n            prompt += (\"\" if prompt == \"\" else \"\\n\") + line\r\n\r\n    for k, v in re_param.findall(lastline):\r\n        try:\r\n            if v[0] == '\"' and v[-1] == '\"':\r\n                v = unquote(v)\r\n\r\n            m = re_imagesize.match(v)\r\n            if m is not None:\r\n                res[f\"{k}-1\"] = m.group(1)\r\n                res[f\"{k}-2\"] = m.group(2)\r\n            else:\r\n                res[k] = v\r\n        except Exception:\r\n            print(f\"Error parsing \\\"{k}: {v}\\\"\")\r\n\r\n    # Extract styles from prompt\r\n    if shared.opts.infotext_styles != \"Ignore\":\r\n        found_styles, prompt_no_styles, negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)\r\n\r\n        same_hr_styles = True\r\n        if (\"Hires prompt\" in res or \"Hires negative prompt\" in res) and (infotext_ver > infotext_versions.v180_hr_styles if (infotext_ver := infotext_versions.parse_version(res.get(\"Version\"))) else True):\r\n            hr_prompt, hr_negative_prompt = res.get(\"Hires prompt\", prompt), res.get(\"Hires negative prompt\", negative_prompt)\r\n            hr_found_styles, hr_prompt_no_styles, 
hr_negative_prompt_no_styles = shared.prompt_styles.extract_styles_from_prompt(hr_prompt, hr_negative_prompt)\r\n            if same_hr_styles := found_styles == hr_found_styles:\r\n                res[\"Hires prompt\"] = '' if hr_prompt_no_styles == prompt_no_styles else hr_prompt_no_styles\r\n                res['Hires negative prompt'] = '' if hr_negative_prompt_no_styles == negative_prompt_no_styles else hr_negative_prompt_no_styles\r\n\r\n        if same_hr_styles:\r\n            prompt, negative_prompt = prompt_no_styles, negative_prompt_no_styles\r\n            if (shared.opts.infotext_styles == \"Apply if any\" and found_styles) or shared.opts.infotext_styles == \"Apply\":\r\n                res['Styles array'] = found_styles\r\n\r\n    res[\"Prompt\"] = prompt\r\n    res[\"Negative prompt\"] = negative_prompt\r\n\r\n    # Missing CLIP skip means it was set to 1 (the default)\r\n    if \"Clip skip\" not in res:\r\n        res[\"Clip skip\"] = \"1\"\r\n\r\n    hypernet = res.get(\"Hypernet\", None)\r\n    if hypernet is not None:\r\n        res[\"Prompt\"] += f\"\"\"<hypernet:{hypernet}:{res.get(\"Hypernet strength\", \"1.0\")}>\"\"\"\r\n\r\n    if \"Hires resize-1\" not in res:\r\n        res[\"Hires resize-1\"] = 0\r\n        res[\"Hires resize-2\"] = 0\r\n\r\n    if \"Hires sampler\" not in res:\r\n        res[\"Hires sampler\"] = \"Use same sampler\"\r\n\r\n    if \"Hires schedule type\" not in res:\r\n        res[\"Hires schedule type\"] = \"Use same scheduler\"\r\n\r\n    if \"Hires checkpoint\" not in res:\r\n        res[\"Hires checkpoint\"] = \"Use same checkpoint\"\r\n\r\n    if \"Hires prompt\" not in res:\r\n        res[\"Hires prompt\"] = \"\"\r\n\r\n    if \"Hires negative prompt\" not in res:\r\n        res[\"Hires negative prompt\"] = \"\"\r\n\r\n    if \"Mask mode\" not in res:\r\n        res[\"Mask mode\"] = \"Inpaint masked\"\r\n\r\n    if \"Masked content\" not in res:\r\n        res[\"Masked content\"] = 'original'\r\n\r\n    if 
\"Inpaint area\" not in res:\r\n        res[\"Inpaint area\"] = \"Whole picture\"\r\n\r\n    if \"Masked area padding\" not in res:\r\n        res[\"Masked area padding\"] = 32\r\n\r\n    restore_old_hires_fix_params(res)\r\n\r\n    # Missing RNG means the default was set, which is GPU RNG\r\n    if \"RNG\" not in res:\r\n        res[\"RNG\"] = \"GPU\"\r\n\r\n    if \"Schedule type\" not in res:\r\n        res[\"Schedule type\"] = \"Automatic\"\r\n\r\n    if \"Schedule max sigma\" not in res:\r\n        res[\"Schedule max sigma\"] = 0\r\n\r\n    if \"Schedule min sigma\" not in res:\r\n        res[\"Schedule min sigma\"] = 0\r\n\r\n    if \"Schedule rho\" not in res:\r\n        res[\"Schedule rho\"] = 0\r\n\r\n    if \"VAE Encoder\" not in res:\r\n        res[\"VAE Encoder\"] = \"Full\"\r\n\r\n    if \"VAE Decoder\" not in res:\r\n        res[\"VAE Decoder\"] = \"Full\"\r\n\r\n    if \"FP8 weight\" not in res:\r\n        res[\"FP8 weight\"] = \"Disable\"\r\n\r\n    if \"Cache FP16 weight for LoRA\" not in res and res[\"FP8 weight\"] != \"Disable\":\r\n        res[\"Cache FP16 weight for LoRA\"] = False\r\n\r\n    prompt_attention = prompt_parser.parse_prompt_attention(prompt)\r\n    prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)\r\n    prompt_uses_emphasis = len(prompt_attention) != len([p for p in prompt_attention if p[1] == 1.0 or p[0] == 'BREAK'])\r\n    if \"Emphasis\" not in res and prompt_uses_emphasis:\r\n        res[\"Emphasis\"] = \"Original\"\r\n\r\n    if \"Refiner switch by sampling steps\" not in res:\r\n        res[\"Refiner switch by sampling steps\"] = False\r\n\r\n    infotext_versions.backcompat(res)\r\n\r\n    for key in skip_fields:\r\n        res.pop(key, None)\r\n\r\n    return res\r\n\r\n\r\ninfotext_to_setting_name_mapping = [\r\n\r\n]\r\n\"\"\"Mapping of infotext labels to setting names. 
Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.\r\nExample content:\r\n\r\ninfotext_to_setting_name_mapping = [\r\n    ('Conditional mask weight', 'inpainting_mask_weight'),\r\n    ('Model hash', 'sd_model_checkpoint'),\r\n    ('ENSD', 'eta_noise_seed_delta'),\r\n    ('Schedule type', 'k_sched_type'),\r\n]\r\n\"\"\"\r\n\r\n\r\ndef create_override_settings_dict(text_pairs):\r\n    \"\"\"creates processing's override_settings parameters from gradio's multiselect\r\n\r\n    Example input:\r\n        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']\r\n\r\n    Example output:\r\n        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}\r\n    \"\"\"\r\n\r\n    res = {}\r\n\r\n    params = {}\r\n    for pair in text_pairs:\r\n        k, v = pair.split(\":\", maxsplit=1)\r\n\r\n        params[k] = v.strip()\r\n\r\n    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]\r\n    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:\r\n        value = params.get(param_name, None)\r\n\r\n        if value is None:\r\n            continue\r\n\r\n        res[setting_name] = shared.opts.cast_value(setting_name, value)\r\n\r\n    return res\r\n\r\n\r\ndef get_override_settings(params, *, skip_fields=None):\r\n    \"\"\"Returns a list of settings overrides from the infotext parameters dictionary.\r\n\r\n    This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns\r\n    a list of tuples containing the parameter name, setting name, and new value cast to correct type.\r\n\r\n    It checks for conditions before adding an override:\r\n    - ignores settings that match the current value\r\n    - ignores parameter keys present in skip_fields argument.\r\n\r\n    Example input:\r\n        {\"Clip skip\": \"2\"}\r\n\r\n    Example output:\r\n        [(\"Clip skip\", 
\"CLIP_stop_at_last_layers\", 2)]\r\n    \"\"\"\r\n\r\n    res = []\r\n\r\n    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]\r\n    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:\r\n        if param_name in (skip_fields or {}):\r\n            continue\r\n\r\n        v = params.get(param_name, None)\r\n        if v is None:\r\n            continue\r\n\r\n        if setting_name == \"sd_model_checkpoint\" and shared.opts.disable_weights_auto_swap:\r\n            continue\r\n\r\n        v = shared.opts.cast_value(setting_name, v)\r\n        current_value = getattr(shared.opts, setting_name, None)\r\n\r\n        if v == current_value:\r\n            continue\r\n\r\n        res.append((param_name, setting_name, v))\r\n\r\n    return res\r\n\r\n\r\ndef connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):\r\n    def paste_func(prompt):\r\n        if not prompt and not shared.cmd_opts.hide_ui_dir_config and not shared.cmd_opts.no_prompt_history:\r\n            filename = os.path.join(data_path, \"params.txt\")\r\n            try:\r\n                with open(filename, \"r\", encoding=\"utf8\") as file:\r\n                    prompt = file.read()\r\n            except OSError:\r\n                pass\r\n\r\n        params = parse_generation_parameters(prompt)\r\n        script_callbacks.infotext_pasted_callback(prompt, params)\r\n        res = []\r\n\r\n        for output, key in paste_fields:\r\n            if callable(key):\r\n                try:\r\n                    v = key(params)\r\n                except Exception:\r\n                    errors.report(f\"Error executing {key}\", exc_info=True)\r\n                    v = None\r\n            else:\r\n                v = params.get(key, None)\r\n\r\n            if v is None:\r\n                res.append(gr.update())\r\n            elif isinstance(v, type_of_gr_update):\r\n                res.append(v)\r\n   
         else:\r\n                try:\r\n                    valtype = type(output.value)\r\n\r\n                    if valtype == bool and v == \"False\":\r\n                        val = False\r\n                    elif valtype == int:\r\n                        val = float(v)\r\n                    else:\r\n                        val = valtype(v)\r\n\r\n                    res.append(gr.update(value=val))\r\n                except Exception:\r\n                    res.append(gr.update())\r\n\r\n        return res\r\n\r\n    if override_settings_component is not None:\r\n        already_handled_fields = {key: 1 for _, key in paste_fields}\r\n\r\n        def paste_settings(params):\r\n            vals = get_override_settings(params, skip_fields=already_handled_fields)\r\n\r\n            vals_pairs = [f\"{infotext_text}: {value}\" for infotext_text, setting_name, value in vals]\r\n\r\n            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))\r\n\r\n        paste_fields = paste_fields + [(override_settings_component, paste_settings)]\r\n\r\n    button.click(\r\n        fn=paste_func,\r\n        inputs=[input_comp],\r\n        outputs=[x[0] for x in paste_fields],\r\n        show_progress=False,\r\n    )\r\n    button.click(\r\n        fn=None,\r\n        _js=f\"recalculate_prompts_{tabname}\",\r\n        inputs=[],\r\n        outputs=[],\r\n        show_progress=False,\r\n    )\r\n\r\n"
  },
  {
    "path": "modules/infotext_versions.py",
    "content": "from modules import shared\r\nfrom packaging import version\r\nimport re\r\n\r\n\r\nv160 = version.parse(\"1.6.0\")\r\nv170_tsnr = version.parse(\"v1.7.0-225\")\r\nv180 = version.parse(\"1.8.0\")\r\nv180_hr_styles = version.parse(\"1.8.0-139\")\r\n\r\n\r\ndef parse_version(text):\r\n    if text is None:\r\n        return None\r\n\r\n    m = re.match(r'([^-]+-[^-]+)-.*', text)\r\n    if m:\r\n        text = m.group(1)\r\n\r\n    try:\r\n        return version.parse(text)\r\n    except Exception:\r\n        return None\r\n\r\n\r\ndef backcompat(d):\r\n    \"\"\"Checks infotext Version field, and enables backwards compatibility options according to it.\"\"\"\r\n\r\n    if not shared.opts.auto_backcompat:\r\n        return\r\n\r\n    ver = parse_version(d.get(\"Version\"))\r\n    if ver is None:\r\n        return\r\n\r\n    if ver < v160 and '[' in d.get('Prompt', ''):\r\n        d[\"Old prompt editing timelines\"] = True\r\n\r\n    if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'):\r\n        d[\"Pad conds v0\"] = True\r\n\r\n    if ver < v170_tsnr:\r\n        d[\"Downcast alphas_cumprod\"] = True\r\n\r\n    if ver < v180 and d.get('Refiner'):\r\n        d[\"Refiner switch by sampling steps\"] = True\r\n"
  },
  {
    "path": "modules/initialize.py",
    "content": "import importlib\r\nimport logging\r\nimport os\r\nimport sys\r\nimport warnings\r\nfrom threading import Thread\r\n\r\nfrom modules.timer import startup_timer\r\n\r\n\r\ndef imports():\r\n    logging.getLogger(\"torch.distributed.nn\").setLevel(logging.ERROR)  # sshh...\r\n    logging.getLogger(\"xformers\").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())\r\n\r\n    import torch  # noqa: F401\r\n    startup_timer.record(\"import torch\")\r\n    import pytorch_lightning  # noqa: F401\r\n    startup_timer.record(\"import torch\")\r\n    warnings.filterwarnings(action=\"ignore\", category=DeprecationWarning, module=\"pytorch_lightning\")\r\n    warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"torchvision\")\r\n\r\n    os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')\r\n    import gradio  # noqa: F401\r\n    startup_timer.record(\"import gradio\")\r\n\r\n    from modules import paths, timer, import_hook, errors  # noqa: F401\r\n    startup_timer.record(\"setup paths\")\r\n\r\n    import ldm.modules.encoders.modules  # noqa: F401\r\n    startup_timer.record(\"import ldm\")\r\n\r\n    import sgm.modules.encoders.modules  # noqa: F401\r\n    startup_timer.record(\"import sgm\")\r\n\r\n    from modules import shared_init\r\n    shared_init.initialize()\r\n    startup_timer.record(\"initialize shared\")\r\n\r\n    from modules import processing, gradio_extensons, ui  # noqa: F401\r\n    startup_timer.record(\"other imports\")\r\n\r\n\r\ndef check_versions():\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    if not cmd_opts.skip_version_check:\r\n        from modules import errors\r\n        errors.check_versions()\r\n\r\n\r\ndef initialize():\r\n    from modules import initialize_util\r\n    initialize_util.fix_torch_version()\r\n    initialize_util.fix_pytorch_lightning()\r\n    initialize_util.fix_asyncio_event_loop_policy()\r\n    
initialize_util.validate_tls_options()\r\n    initialize_util.configure_sigint_handler()\r\n    initialize_util.configure_opts_onchange()\r\n\r\n    from modules import sd_models\r\n    sd_models.setup_model()\r\n    startup_timer.record(\"setup SD model\")\r\n\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    from modules import codeformer_model\r\n    warnings.filterwarnings(action=\"ignore\", category=UserWarning, module=\"torchvision.transforms.functional_tensor\")\r\n    codeformer_model.setup_model(cmd_opts.codeformer_models_path)\r\n    startup_timer.record(\"setup codeformer\")\r\n\r\n    from modules import gfpgan_model\r\n    gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)\r\n    startup_timer.record(\"setup gfpgan\")\r\n\r\n    initialize_rest(reload_script_modules=False)\r\n\r\n\r\ndef initialize_rest(*, reload_script_modules=False):\r\n    \"\"\"\r\n    Called both from initialize() and when reloading the webui.\r\n    \"\"\"\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    from modules import sd_samplers\r\n    sd_samplers.set_samplers()\r\n    startup_timer.record(\"set samplers\")\r\n\r\n    from modules import extensions\r\n    extensions.list_extensions()\r\n    startup_timer.record(\"list extensions\")\r\n\r\n    from modules import initialize_util\r\n    initialize_util.restore_config_state_file()\r\n    startup_timer.record(\"restore config state file\")\r\n\r\n    from modules import shared, upscaler, scripts\r\n    if cmd_opts.ui_debug_mode:\r\n        shared.sd_upscalers = upscaler.UpscalerLanczos().scalers\r\n        scripts.load_scripts()\r\n        return\r\n\r\n    from modules import sd_models\r\n    sd_models.list_models()\r\n    startup_timer.record(\"list SD models\")\r\n\r\n    from modules import localization\r\n    localization.list_localizations(cmd_opts.localizations_dir)\r\n    startup_timer.record(\"list localizations\")\r\n\r\n    with startup_timer.subcategory(\"load scripts\"):\r\n     
   scripts.load_scripts()\r\n\r\n    if reload_script_modules and shared.opts.enable_reloading_ui_scripts:\r\n        for module in [module for name, module in sys.modules.items() if name.startswith(\"modules.ui\")]:\r\n            importlib.reload(module)\r\n        startup_timer.record(\"reload script modules\")\r\n\r\n    from modules import modelloader\r\n    modelloader.load_upscalers()\r\n    startup_timer.record(\"load upscalers\")\r\n\r\n    from modules import sd_vae\r\n    sd_vae.refresh_vae_list()\r\n    startup_timer.record(\"refresh VAE\")\r\n\r\n    from modules import textual_inversion\r\n    textual_inversion.textual_inversion.list_textual_inversion_templates()\r\n    startup_timer.record(\"refresh textual inversion templates\")\r\n\r\n    from modules import script_callbacks, sd_hijack_optimizations, sd_hijack\r\n    script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)\r\n    sd_hijack.list_optimizers()\r\n    startup_timer.record(\"scripts list_optimizers\")\r\n\r\n    from modules import sd_unet\r\n    sd_unet.list_unets()\r\n    startup_timer.record(\"scripts list_unets\")\r\n\r\n    def load_model():\r\n        \"\"\"\r\n        Accesses shared.sd_model property to load model.\r\n        After it's available, if it has been loaded before this access by some extension,\r\n        its optimization may be None because the list of optimizers has not been filled\r\n        by that time, so we apply optimization again.\r\n        \"\"\"\r\n        from modules import devices\r\n        devices.torch_npu_set_device()\r\n\r\n        shared.sd_model  # noqa: B018\r\n\r\n        if sd_hijack.current_optimizer is None:\r\n            sd_hijack.apply_optimizations()\r\n\r\n        devices.first_time_calculation()\r\n    if not shared.cmd_opts.skip_load_model_at_start:\r\n        Thread(target=load_model).start()\r\n\r\n    from modules import shared_items\r\n    shared_items.reload_hypernetworks()\r\n    
startup_timer.record(\"reload hypernetworks\")\r\n\r\n    from modules import ui_extra_networks\r\n    ui_extra_networks.initialize()\r\n    ui_extra_networks.register_default_pages()\r\n\r\n    from modules import extra_networks\r\n    extra_networks.initialize()\r\n    extra_networks.register_default_extra_networks()\r\n    startup_timer.record(\"initialize extra networks\")\r\n"
  },
  {
    "path": "modules/initialize_util.py",
    "content": "import json\r\nimport os\r\nimport signal\r\nimport sys\r\nimport re\r\n\r\nfrom modules.timer import startup_timer\r\n\r\n\r\ndef gradio_server_name():\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    if cmd_opts.server_name:\r\n        return cmd_opts.server_name\r\n    else:\r\n        return \"0.0.0.0\" if cmd_opts.listen else None\r\n\r\n\r\ndef fix_torch_version():\r\n    import torch\r\n\r\n    # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors\r\n    if \".dev\" in torch.__version__ or \"+git\" in torch.__version__:\r\n        torch.__long_version__ = torch.__version__\r\n        torch.__version__ = re.search(r'[\\d.]+[\\d]', torch.__version__).group(0)\r\n\r\ndef fix_pytorch_lightning():\r\n    # Checks if pytorch_lightning.utilities.distributed already exists in the sys.modules cache\r\n    if 'pytorch_lightning.utilities.distributed' not in sys.modules:\r\n        import pytorch_lightning\r\n        # Lets the user know that the library was not found and then will set it to pytorch_lightning.utilities.rank_zero\r\n        print(\"Pytorch_lightning.distributed not found, attempting pytorch_lightning.rank_zero\")\r\n        sys.modules[\"pytorch_lightning.utilities.distributed\"] = pytorch_lightning.utilities.rank_zero\r\n\r\ndef fix_asyncio_event_loop_policy():\r\n    \"\"\"\r\n        The default `asyncio` event loop policy only automatically creates\r\n        event loops in the main threads. Other threads must create event\r\n        loops explicitly or `asyncio.get_event_loop` (and therefore\r\n        `.IOLoop.current`) will fail. 
Installing this policy allows event\r\n        loops to be created automatically on any thread, matching the\r\n        behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).\r\n    \"\"\"\r\n\r\n    import asyncio\r\n\r\n    if sys.platform == \"win32\" and hasattr(asyncio, \"WindowsSelectorEventLoopPolicy\"):\r\n        # \"Any thread\" and \"selector\" should be orthogonal, but there's not a clean\r\n        # interface for composing policies so pick the right base.\r\n        _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore\r\n    else:\r\n        _BasePolicy = asyncio.DefaultEventLoopPolicy\r\n\r\n    class AnyThreadEventLoopPolicy(_BasePolicy):  # type: ignore\r\n        \"\"\"Event loop policy that allows loop creation on any thread.\r\n        Usage::\r\n\r\n            asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())\r\n        \"\"\"\r\n\r\n        def get_event_loop(self) -> asyncio.AbstractEventLoop:\r\n            try:\r\n                return super().get_event_loop()\r\n            except (RuntimeError, AssertionError):\r\n                # This was an AssertionError in python 3.4.2 (which ships with debian jessie)\r\n                # and changed to a RuntimeError in 3.4.3.\r\n                # \"There is no current event loop in thread %r\"\r\n                loop = self.new_event_loop()\r\n                self.set_event_loop(loop)\r\n                return loop\r\n\r\n    asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())\r\n\r\n\r\ndef restore_config_state_file():\r\n    from modules import shared, config_states\r\n\r\n    config_state_file = shared.opts.restore_config_state_file\r\n    if config_state_file == \"\":\r\n        return\r\n\r\n    shared.opts.restore_config_state_file = \"\"\r\n    shared.opts.save(shared.config_filename)\r\n\r\n    if os.path.isfile(config_state_file):\r\n        print(f\"*** About to restore extension state from file: {config_state_file}\")\r\n        with 
open(config_state_file, \"r\", encoding=\"utf-8\") as f:\r\n            config_state = json.load(f)\r\n            config_states.restore_extension_config(config_state)\r\n        startup_timer.record(\"restore extension config\")\r\n    elif config_state_file:\r\n        print(f\"!!! Config state backup not found: {config_state_file}\")\r\n\r\n\r\ndef validate_tls_options():\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):\r\n        return\r\n\r\n    try:\r\n        if not os.path.exists(cmd_opts.tls_keyfile):\r\n            print(\"Invalid path to TLS keyfile given\")\r\n        if not os.path.exists(cmd_opts.tls_certfile):\r\n            print(f\"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'\")\r\n    except TypeError:\r\n        cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None\r\n        print(\"TLS setup invalid, running webui without TLS\")\r\n    else:\r\n        print(\"Running with TLS\")\r\n    startup_timer.record(\"TLS\")\r\n\r\n\r\ndef get_gradio_auth_creds():\r\n    \"\"\"\r\n    Convert the gradio_auth and gradio_auth_path commandline arguments into\r\n    an iterable of (username, password) tuples.\r\n    \"\"\"\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    def process_credential_line(s):\r\n        s = s.strip()\r\n        if not s:\r\n            return None\r\n        return tuple(s.split(':', 1))\r\n\r\n    if cmd_opts.gradio_auth:\r\n        for cred in cmd_opts.gradio_auth.split(','):\r\n            cred = process_credential_line(cred)\r\n            if cred:\r\n                yield cred\r\n\r\n    if cmd_opts.gradio_auth_path:\r\n        with open(cmd_opts.gradio_auth_path, 'r', encoding=\"utf8\") as file:\r\n            for line in file.readlines():\r\n                for cred in line.strip().split(','):\r\n                    cred = process_credential_line(cred)\r\n                    if cred:\r\n                        yield 
cred\r\n\r\n\r\ndef dumpstacks():\r\n    import threading\r\n    import traceback\r\n\r\n    id2name = {th.ident: th.name for th in threading.enumerate()}\r\n    code = []\r\n    for threadId, stack in sys._current_frames().items():\r\n        code.append(f\"\\n# Thread: {id2name.get(threadId, '')}({threadId})\")\r\n        for filename, lineno, name, line in traceback.extract_stack(stack):\r\n            code.append(f\"\"\"File: \"{filename}\", line {lineno}, in {name}\"\"\")\r\n            if line:\r\n                code.append(\"  \" + line.strip())\r\n\r\n    print(\"\\n\".join(code))\r\n\r\n\r\ndef configure_sigint_handler():\r\n    # make the program just exit at ctrl+c without waiting for anything\r\n\r\n    from modules import shared\r\n\r\n    def sigint_handler(sig, frame):\r\n        print(f'Interrupted with signal {sig} in {frame}')\r\n\r\n        if shared.opts.dump_stacks_on_signal:\r\n            dumpstacks()\r\n\r\n        os._exit(0)\r\n\r\n    if not os.environ.get(\"COVERAGE_RUN\"):\r\n        # Don't install the immediate-quit handler when running under coverage,\r\n        # as then the coverage report won't be generated.\r\n        signal.signal(signal.SIGINT, sigint_handler)\r\n\r\n\r\ndef configure_opts_onchange():\r\n    from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack\r\n    from modules.call_queue import wrap_queued_call\r\n\r\n    shared.opts.onchange(\"sd_model_checkpoint\", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)\r\n    shared.opts.onchange(\"sd_vae\", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)\r\n    shared.opts.onchange(\"sd_vae_overrides_per_model_preferences\", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)\r\n    shared.opts.onchange(\"temp_dir\", ui_tempdir.on_tmpdir_changed)\r\n    shared.opts.onchange(\"gradio_theme\", shared.reload_gradio_theme)\r\n    shared.opts.onchange(\"cross_attention_optimization\", wrap_queued_call(lambda: 
sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)\r\n    shared.opts.onchange(\"fp8_storage\", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)\r\n    shared.opts.onchange(\"cache_fp16_weight\", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)\r\n    startup_timer.record(\"opts onchange\")\r\n\r\n\r\ndef setup_middleware(app):\r\n    from starlette.middleware.gzip import GZipMiddleware\r\n\r\n    app.middleware_stack = None  # reset current middleware to allow modifying user provided list\r\n    app.add_middleware(GZipMiddleware, minimum_size=1000)\r\n    configure_cors_middleware(app)\r\n    app.build_middleware_stack()  # rebuild middleware stack on-the-fly\r\n\r\n\r\ndef configure_cors_middleware(app):\r\n    from starlette.middleware.cors import CORSMiddleware\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    cors_options = {\r\n        \"allow_methods\": [\"*\"],\r\n        \"allow_headers\": [\"*\"],\r\n        \"allow_credentials\": True,\r\n    }\r\n    if cmd_opts.cors_allow_origins:\r\n        cors_options[\"allow_origins\"] = cmd_opts.cors_allow_origins.split(',')\r\n    if cmd_opts.cors_allow_origins_regex:\r\n        cors_options[\"allow_origin_regex\"] = cmd_opts.cors_allow_origins_regex\r\n    app.add_middleware(CORSMiddleware, **cors_options)\r\n\r\n"
  },
  {
    "path": "modules/interrogate.py",
    "content": "import os\r\nimport sys\r\nfrom collections import namedtuple\r\nfrom pathlib import Path\r\nimport re\r\n\r\nimport torch\r\nimport torch.hub\r\n\r\nfrom torchvision import transforms\r\nfrom torchvision.transforms.functional import InterpolationMode\r\n\r\nfrom modules import devices, paths, shared, lowvram, modelloader, errors, torch_utils\r\n\r\nblip_image_eval_size = 384\r\nclip_model_name = 'ViT-L/14'\r\n\r\nCategory = namedtuple(\"Category\", [\"name\", \"topn\", \"items\"])\r\n\r\nre_topn = re.compile(r\"\\.top(\\d+)$\")\r\n\r\ndef category_types():\r\n    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]\r\n\r\n\r\ndef download_default_clip_interrogate_categories(content_dir):\r\n    print(\"Downloading CLIP categories...\")\r\n\r\n    tmpdir = f\"{content_dir}_tmp\"\r\n    category_types = [\"artists\", \"flavors\", \"mediums\", \"movements\"]\r\n\r\n    try:\r\n        os.makedirs(tmpdir, exist_ok=True)\r\n        for category_type in category_types:\r\n            torch.hub.download_url_to_file(f\"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt\", os.path.join(tmpdir, f\"{category_type}.txt\"))\r\n        os.rename(tmpdir, content_dir)\r\n\r\n    except Exception as e:\r\n        errors.display(e, \"downloading default CLIP interrogate categories\")\r\n    finally:\r\n        if os.path.exists(tmpdir):\r\n            os.removedirs(tmpdir)\r\n\r\n\r\nclass InterrogateModels:\r\n    blip_model = None\r\n    clip_model = None\r\n    clip_preprocess = None\r\n    dtype = None\r\n    running_on_cpu = None\r\n\r\n    def __init__(self, content_dir):\r\n        self.loaded_categories = None\r\n        self.skip_categories = []\r\n        self.content_dir = content_dir\r\n        self.running_on_cpu = devices.device_interrogate == torch.device(\"cpu\")\r\n\r\n    def categories(self):\r\n        if not os.path.exists(self.content_dir):\r\n            
download_default_clip_interrogate_categories(self.content_dir)\r\n\r\n        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:\r\n           return self.loaded_categories\r\n\r\n        self.loaded_categories = []\r\n\r\n        if os.path.exists(self.content_dir):\r\n            self.skip_categories = shared.opts.interrogate_clip_skip_categories\r\n            category_types = []\r\n            for filename in Path(self.content_dir).glob('*.txt'):\r\n                category_types.append(filename.stem)\r\n                if filename.stem in self.skip_categories:\r\n                    continue\r\n                m = re_topn.search(filename.stem)\r\n                topn = 1 if m is None else int(m.group(1))\r\n                with open(filename, \"r\", encoding=\"utf8\") as file:\r\n                    lines = [x.strip() for x in file.readlines()]\r\n\r\n                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))\r\n\r\n        return self.loaded_categories\r\n\r\n    def create_fake_fairscale(self):\r\n        class FakeFairscale:\r\n            def checkpoint_wrapper(self):\r\n                pass\r\n\r\n        sys.modules[\"fairscale.nn.checkpoint.checkpoint_activations\"] = FakeFairscale\r\n\r\n    def load_blip_model(self):\r\n        self.create_fake_fairscale()\r\n        import models.blip\r\n\r\n        files = modelloader.load_models(\r\n            model_path=os.path.join(paths.models_path, \"BLIP\"),\r\n            model_url='https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth',\r\n            ext_filter=[\".pth\"],\r\n            download_name='model_base_caption_capfilt_large.pth',\r\n        )\r\n\r\n        blip_model = models.blip.blip_decoder(pretrained=files[0], image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths[\"BLIP\"], \"configs\", 
\"med_config.json\"))\r\n        blip_model.eval()\r\n\r\n        return blip_model\r\n\r\n    def load_clip_model(self):\r\n        import clip\r\n\r\n        if self.running_on_cpu:\r\n            model, preprocess = clip.load(clip_model_name, device=\"cpu\", download_root=shared.cmd_opts.clip_models_path)\r\n        else:\r\n            model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)\r\n\r\n        model.eval()\r\n        model = model.to(devices.device_interrogate)\r\n\r\n        return model, preprocess\r\n\r\n    def load(self):\r\n        if self.blip_model is None:\r\n            self.blip_model = self.load_blip_model()\r\n            if not shared.cmd_opts.no_half and not self.running_on_cpu:\r\n                self.blip_model = self.blip_model.half()\r\n\r\n        self.blip_model = self.blip_model.to(devices.device_interrogate)\r\n\r\n        if self.clip_model is None:\r\n            self.clip_model, self.clip_preprocess = self.load_clip_model()\r\n            if not shared.cmd_opts.no_half and not self.running_on_cpu:\r\n                self.clip_model = self.clip_model.half()\r\n\r\n        self.clip_model = self.clip_model.to(devices.device_interrogate)\r\n\r\n        self.dtype = torch_utils.get_param(self.clip_model).dtype\r\n\r\n    def send_clip_to_ram(self):\r\n        if not shared.opts.interrogate_keep_models_in_memory:\r\n            if self.clip_model is not None:\r\n                self.clip_model = self.clip_model.to(devices.cpu)\r\n\r\n    def send_blip_to_ram(self):\r\n        if not shared.opts.interrogate_keep_models_in_memory:\r\n            if self.blip_model is not None:\r\n                self.blip_model = self.blip_model.to(devices.cpu)\r\n\r\n    def unload(self):\r\n        self.send_clip_to_ram()\r\n        self.send_blip_to_ram()\r\n\r\n        devices.torch_gc()\r\n\r\n    def rank(self, image_features, text_array, top_count=1):\r\n        import clip\r\n\r\n        
devices.torch_gc()\r\n\r\n        if shared.opts.interrogate_clip_dict_limit != 0:\r\n            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]\r\n\r\n        top_count = min(top_count, len(text_array))\r\n        text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)\r\n        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)\r\n        text_features /= text_features.norm(dim=-1, keepdim=True)\r\n\r\n        similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)\r\n        for i in range(image_features.shape[0]):\r\n            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)\r\n        similarity /= image_features.shape[0]\r\n\r\n        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)\r\n        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]\r\n\r\n    def generate_caption(self, pil_image):\r\n        gpu_image = transforms.Compose([\r\n            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),\r\n            transforms.ToTensor(),\r\n            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\r\n        ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)\r\n\r\n        with torch.no_grad():\r\n            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)\r\n\r\n        return caption[0]\r\n\r\n    def interrogate(self, pil_image):\r\n        res = \"\"\r\n        shared.state.begin(job=\"interrogate\")\r\n        try:\r\n            lowvram.send_everything_to_cpu()\r\n            devices.torch_gc()\r\n\r\n            self.load()\r\n\r\n            
caption = self.generate_caption(pil_image)\r\n            self.send_blip_to_ram()\r\n            devices.torch_gc()\r\n\r\n            res = caption\r\n\r\n            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)\r\n\r\n            with torch.no_grad(), devices.autocast():\r\n                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)\r\n\r\n                image_features /= image_features.norm(dim=-1, keepdim=True)\r\n\r\n                for cat in self.categories():\r\n                    matches = self.rank(image_features, cat.items, top_count=cat.topn)\r\n                    for match, score in matches:\r\n                        if shared.opts.interrogate_return_ranks:\r\n                            res += f\", ({match}:{score/100:.3f})\"\r\n                        else:\r\n                            res += f\", {match}\"\r\n\r\n        except Exception:\r\n            errors.report(\"Error interrogating\", exc_info=True)\r\n            res += \"<error>\"\r\n\r\n        self.unload()\r\n        shared.state.end()\r\n\r\n        return res\r\n"
  },
  {
    "path": "modules/launch_utils.py",
    "content": "# this scripts installs necessary requirements and launches main program in webui.py\r\nimport logging\r\nimport re\r\nimport subprocess\r\nimport os\r\nimport shutil\r\nimport sys\r\nimport importlib.util\r\nimport importlib.metadata\r\nimport platform\r\nimport json\r\nimport shlex\r\nfrom functools import lru_cache\r\n\r\nfrom modules import cmd_args, errors\r\nfrom modules.paths_internal import script_path, extensions_dir\r\nfrom modules.timer import startup_timer\r\nfrom modules import logging_config\r\n\r\nargs, _ = cmd_args.parser.parse_known_args()\r\nlogging_config.setup_logging(args.loglevel)\r\n\r\npython = sys.executable\r\ngit = os.environ.get('GIT', \"git\")\r\nindex_url = os.environ.get('INDEX_URL', \"\")\r\ndir_repos = \"repositories\"\r\n\r\n# Whether to default to printing command output\r\ndefault_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == \"1\")\r\n\r\nos.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')\r\n\r\n\r\ndef check_python_version():\r\n    is_windows = platform.system() == \"Windows\"\r\n    major = sys.version_info.major\r\n    minor = sys.version_info.minor\r\n    micro = sys.version_info.micro\r\n\r\n    if is_windows:\r\n        supported_minors = [10]\r\n    else:\r\n        supported_minors = [7, 8, 9, 10, 11]\r\n\r\n    if not (major == 3 and minor in supported_minors):\r\n        import modules.errors\r\n\r\n        modules.errors.print_error_explanation(f\"\"\"\r\nINCOMPATIBLE PYTHON VERSION\r\n\r\nThis program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.\r\nIf you encounter an error with \"RuntimeError: Couldn't install torch.\" message,\r\nor any other error regarding unsuccessful package (library) installation,\r\nplease downgrade (or upgrade) to the latest version of 3.10 Python\r\nand delete current Python and \"venv\" folder in WebUI's directory.\r\n\r\nYou can download 3.10 Python from here: 
https://www.python.org/downloads/release/python-3106/\r\n\r\n{\"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre\" if is_windows else \"\"}\r\n\r\nUse --skip-python-version-check to suppress this warning.\r\n\"\"\")\r\n\r\n\r\n@lru_cache()\r\ndef commit_hash():\r\n    try:\r\n        return subprocess.check_output([git, \"-C\", script_path, \"rev-parse\", \"HEAD\"], shell=False, encoding='utf8').strip()\r\n    except Exception:\r\n        return \"<none>\"\r\n\r\n\r\n@lru_cache()\r\ndef git_tag():\r\n    try:\r\n        return subprocess.check_output([git, \"-C\", script_path, \"describe\", \"--tags\"], shell=False, encoding='utf8').strip()\r\n    except Exception:\r\n        try:\r\n\r\n            changelog_md = os.path.join(script_path, \"CHANGELOG.md\")\r\n            with open(changelog_md, \"r\", encoding=\"utf-8\") as file:\r\n                line = next((line.strip() for line in file if line.strip()), \"<none>\")\r\n                line = line.replace(\"## \", \"\")\r\n                return line\r\n        except Exception:\r\n            return \"<none>\"\r\n\r\n\r\ndef run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:\r\n    if desc is not None:\r\n        print(desc)\r\n\r\n    run_kwargs = {\r\n        \"args\": command,\r\n        \"shell\": True,\r\n        \"env\": os.environ if custom_env is None else custom_env,\r\n        \"encoding\": 'utf8',\r\n        \"errors\": 'ignore',\r\n    }\r\n\r\n    if not live:\r\n        run_kwargs[\"stdout\"] = run_kwargs[\"stderr\"] = subprocess.PIPE\r\n\r\n    result = subprocess.run(**run_kwargs)\r\n\r\n    if result.returncode != 0:\r\n        error_bits = [\r\n            f\"{errdesc or 'Error running command'}.\",\r\n            f\"Command: {command}\",\r\n            f\"Error code: {result.returncode}\",\r\n        ]\r\n        if result.stdout:\r\n            
error_bits.append(f\"stdout: {result.stdout}\")\r\n        if result.stderr:\r\n            error_bits.append(f\"stderr: {result.stderr}\")\r\n        raise RuntimeError(\"\\n\".join(error_bits))\r\n\r\n    return (result.stdout or \"\")\r\n\r\n\r\ndef is_installed(package):\r\n    try:\r\n        dist = importlib.metadata.distribution(package)\r\n    except importlib.metadata.PackageNotFoundError:\r\n        try:\r\n            spec = importlib.util.find_spec(package)\r\n        except ModuleNotFoundError:\r\n            return False\r\n\r\n        return spec is not None\r\n\r\n    return dist is not None\r\n\r\n\r\ndef repo_dir(name):\r\n    return os.path.join(script_path, dir_repos, name)\r\n\r\n\r\ndef run_pip(command, desc=None, live=default_command_live):\r\n    if args.skip_install:\r\n        return\r\n\r\n    index_url_line = f' --index-url {index_url}' if index_url != '' else ''\r\n    return run(f'\"{python}\" -m pip {command} --prefer-binary{index_url_line}', desc=f\"Installing {desc}\", errdesc=f\"Couldn't install {desc}\", live=live)\r\n\r\n\r\ndef check_run_python(code: str) -> bool:\r\n    result = subprocess.run([python, \"-c\", code], capture_output=True, shell=False)\r\n    return result.returncode == 0\r\n\r\n\r\ndef git_fix_workspace(dir, name):\r\n    run(f'\"{git}\" -C \"{dir}\" fetch --refetch --no-auto-gc', f\"Fetching all contents for {name}\", f\"Couldn't fetch {name}\", live=True)\r\n    run(f'\"{git}\" -C \"{dir}\" gc --aggressive --prune=now', f\"Pruning {name}\", f\"Couldn't prune {name}\", live=True)\r\n    return\r\n\r\n\r\ndef run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True):\r\n    try:\r\n        return run(f'\"{git}\" -C \"{dir}\" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)\r\n    except RuntimeError:\r\n        if not autofix:\r\n            raise\r\n\r\n    print(f\"{errdesc}, attempting autofix...\")\r\n    
git_fix_workspace(dir, name)\r\n\r\n    return run(f'\"{git}\" -C \"{dir}\" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live)\r\n\r\n\r\ndef git_clone(url, dir, name, commithash=None):\r\n    # TODO clone into temporary dir and move if successful\r\n\r\n    if os.path.exists(dir):\r\n        if commithash is None:\r\n            return\r\n\r\n        current_hash = run_git(dir, name, 'rev-parse HEAD', None, f\"Couldn't determine {name}'s hash: {commithash}\", live=False).strip()\r\n        if current_hash == commithash:\r\n            return\r\n\r\n        if run_git(dir, name, 'config --get remote.origin.url', None, f\"Couldn't determine {name}'s origin URL\", live=False).strip() != url:\r\n            run_git(dir, name, f'remote set-url origin \"{url}\"', None, f\"Failed to set {name}'s origin URL\", live=False)\r\n\r\n        run_git(dir, name, 'fetch', f\"Fetching updates for {name}...\", f\"Couldn't fetch {name}\", autofix=False)\r\n\r\n        run_git(dir, name, f'checkout {commithash}', f\"Checking out commit for {name} with hash: {commithash}...\", f\"Couldn't checkout commit {commithash} for {name}\", live=True)\r\n\r\n        return\r\n\r\n    try:\r\n        run(f'\"{git}\" clone --config core.filemode=false \"{url}\" \"{dir}\"', f\"Cloning {name} into {dir}...\", f\"Couldn't clone {name}\", live=True)\r\n    except RuntimeError:\r\n        shutil.rmtree(dir, ignore_errors=True)\r\n        raise\r\n\r\n    if commithash is not None:\r\n        run(f'\"{git}\" -C \"{dir}\" checkout {commithash}', None, \"Couldn't checkout {name}'s hash: {commithash}\")\r\n\r\n\r\ndef git_pull_recursive(dir):\r\n    for subdir, _, _ in os.walk(dir):\r\n        if os.path.exists(os.path.join(subdir, '.git')):\r\n            try:\r\n                output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])\r\n                print(f\"Pulled changes for repository in '{subdir}':\\n{output.decode('utf-8').strip()}\\n\")\r\n            
except subprocess.CalledProcessError as e:\r\n                print(f\"Couldn't perform 'git pull' on repository in '{subdir}':\\n{e.output.decode('utf-8').strip()}\\n\")\r\n\r\n\r\ndef version_check(commit):\r\n    try:\r\n        import requests\r\n        commits = requests.get('https://api.github.com/repos/AUTOMATIC1111/stable-diffusion-webui/branches/master').json()\r\n        if commit != \"<none>\" and commits['commit']['sha'] != commit:\r\n            print(\"--------------------------------------------------------\")\r\n            print(\"| You are not up to date with the most recent release. |\")\r\n            print(\"| Consider running `git pull` to update.               |\")\r\n            print(\"--------------------------------------------------------\")\r\n        elif commits['commit']['sha'] == commit:\r\n            print(\"You are up to date with the most recent release.\")\r\n        else:\r\n            print(\"Not a git clone, can't perform version check.\")\r\n    except Exception as e:\r\n        print(\"version check failed\", e)\r\n\r\n\r\ndef run_extension_installer(extension_dir):\r\n    path_installer = os.path.join(extension_dir, \"install.py\")\r\n    if not os.path.isfile(path_installer):\r\n        return\r\n\r\n    try:\r\n        env = os.environ.copy()\r\n        env['PYTHONPATH'] = f\"{script_path}{os.pathsep}{env.get('PYTHONPATH', '')}\"\r\n\r\n        stdout = run(f'\"{python}\" \"{path_installer}\"', errdesc=f\"Error running install.py for extension {extension_dir}\", custom_env=env).strip()\r\n        if stdout:\r\n            print(stdout)\r\n    except Exception as e:\r\n        errors.report(str(e))\r\n\r\n\r\ndef list_extensions(settings_file):\r\n    settings = {}\r\n\r\n    try:\r\n        with open(settings_file, \"r\", encoding=\"utf8\") as file:\r\n            settings = json.load(file)\r\n    except FileNotFoundError:\r\n        pass\r\n    except Exception:\r\n        errors.report(f'\\nCould not load 
settings\\nThe config file \"{settings_file}\" is likely corrupted\\nIt has been moved to the \"tmp/config.json\"\\nReverting config to default\\n\\n''', exc_info=True)\r\n        os.replace(settings_file, os.path.join(script_path, \"tmp\", \"config.json\"))\r\n\r\n    disabled_extensions = set(settings.get('disabled_extensions', []))\r\n    disable_all_extensions = settings.get('disable_all_extensions', 'none')\r\n\r\n    if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir):\r\n        return []\r\n\r\n    return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]\r\n\r\n\r\ndef run_extensions_installers(settings_file):\r\n    if not os.path.isdir(extensions_dir):\r\n        return\r\n\r\n    with startup_timer.subcategory(\"run extensions installers\"):\r\n        for dirname_extension in list_extensions(settings_file):\r\n            logging.debug(f\"Installing {dirname_extension}\")\r\n\r\n            path = os.path.join(extensions_dir, dirname_extension)\r\n\r\n            if os.path.isdir(path):\r\n                run_extension_installer(path)\r\n                startup_timer.record(dirname_extension)\r\n\r\n\r\nre_requirement = re.compile(r\"\\s*([-_a-zA-Z0-9]+)\\s*(?:==\\s*([-+_.a-zA-Z0-9]+))?\\s*\")\r\n\r\n\r\ndef requirements_met(requirements_file):\r\n    \"\"\"\r\n    Does a simple parse of a requirements.txt file to determine if all rerqirements in it\r\n    are already installed. 
Returns True if so, False if not installed or parsing fails.\r\n    \"\"\"\r\n\r\n    import importlib.metadata\r\n    import packaging.version\r\n\r\n    with open(requirements_file, \"r\", encoding=\"utf8\") as file:\r\n        for line in file:\r\n            if line.strip() == \"\":\r\n                continue\r\n\r\n            m = re.match(re_requirement, line)\r\n            if m is None:\r\n                return False\r\n\r\n            package = m.group(1).strip()\r\n            version_required = (m.group(2) or \"\").strip()\r\n\r\n            if version_required == \"\":\r\n                continue\r\n\r\n            try:\r\n                version_installed = importlib.metadata.version(package)\r\n            except Exception:\r\n                return False\r\n\r\n            if packaging.version.parse(version_required) != packaging.version.parse(version_installed):\r\n                return False\r\n\r\n    return True\r\n\r\n\r\ndef prepare_environment():\r\n    torch_index_url = os.environ.get('TORCH_INDEX_URL', \"https://download.pytorch.org/whl/cu121\")\r\n    torch_command = os.environ.get('TORCH_COMMAND', f\"pip install torch==2.1.2 torchvision==0.16.2 --extra-index-url {torch_index_url}\")\r\n    if args.use_ipex:\r\n        if platform.system() == \"Windows\":\r\n            # The \"Nuullll/intel-extension-for-pytorch\" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main\r\n            # This is NOT an Intel official release so please use it at your own risk!!\r\n            # See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.\r\n            #\r\n            # Strengths (over official IPEX 2.0.110 windows release):\r\n            #   - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399\r\n            #   - Bundles minimal 
oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.\r\n            #   - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465\r\n            # Limitation:\r\n            #   - Only works for python 3.10\r\n            url_prefix = \"https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle\"\r\n            torch_command = os.environ.get('TORCH_COMMAND', f\"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl\")\r\n        else:\r\n            # Using official IPEX release for linux since it's already an AOT build.\r\n            # However, users still have to install oneAPI toolkit and activate oneAPI environment manually.\r\n            # See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.\r\n            torch_index_url = os.environ.get('TORCH_INDEX_URL', \"https://pytorch-extension.intel.com/release-whl/stable/xpu/us/\")\r\n            torch_command = os.environ.get('TORCH_COMMAND', f\"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}\")\r\n    requirements_file = os.environ.get('REQS_FILE', \"requirements_versions.txt\")\r\n    requirements_file_for_npu = os.environ.get('REQS_FILE_FOR_NPU', \"requirements_npu.txt\")\r\n\r\n    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.23.post1')\r\n    clip_package = os.environ.get('CLIP_PACKAGE', \"https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip\")\r\n    openclip_package = os.environ.get('OPENCLIP_PACKAGE', \"https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip\")\r\n\r\n    assets_repo = 
os.environ.get('ASSETS_REPO', \"https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git\")\r\n    stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', \"https://github.com/Stability-AI/stablediffusion.git\")\r\n    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', \"https://github.com/Stability-AI/generative-models.git\")\r\n    k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')\r\n    blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')\r\n\r\n    assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', \"6f7db241d2f8ba7457bac5ca9753331f0c266917\")\r\n    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', \"cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf\")\r\n    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', \"45c443b316737a4ab6e40413d7794a7f5657c19f\")\r\n    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', \"ab527a9a6d347f364e3d185ba6d714e22d80cb3c\")\r\n    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', \"48211a1594f1321b00f14c9f7a5b4813144b2fb9\")\r\n\r\n    try:\r\n        # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution\r\n        os.remove(os.path.join(script_path, \"tmp\", \"restart\"))\r\n        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')\r\n    except OSError:\r\n        pass\r\n\r\n    if not args.skip_python_version_check:\r\n        check_python_version()\r\n\r\n    startup_timer.record(\"checks\")\r\n\r\n    commit = commit_hash()\r\n    tag = git_tag()\r\n    startup_timer.record(\"git version info\")\r\n\r\n    print(f\"Python {sys.version}\")\r\n    print(f\"Version: {tag}\")\r\n    print(f\"Commit hash: {commit}\")\r\n\r\n    if args.reinstall_torch or not is_installed(\"torch\") or not is_installed(\"torchvision\"):\r\n        run(f'\"{python}\" -m {torch_command}', 
\"Installing torch and torchvision\", \"Couldn't install torch\", live=True)\r\n        startup_timer.record(\"install torch\")\r\n\r\n    if args.use_ipex:\r\n        args.skip_torch_cuda_test = True\r\n    if not args.skip_torch_cuda_test and not check_run_python(\"import torch; assert torch.cuda.is_available()\"):\r\n        raise RuntimeError(\r\n            'Torch is not able to use GPU; '\r\n            'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'\r\n        )\r\n    startup_timer.record(\"torch GPU test\")\r\n\r\n    if not is_installed(\"clip\"):\r\n        run_pip(f\"install {clip_package}\", \"clip\")\r\n        startup_timer.record(\"install clip\")\r\n\r\n    if not is_installed(\"open_clip\"):\r\n        run_pip(f\"install {openclip_package}\", \"open_clip\")\r\n        startup_timer.record(\"install open_clip\")\r\n\r\n    if (not is_installed(\"xformers\") or args.reinstall_xformers) and args.xformers:\r\n        run_pip(f\"install -U -I --no-deps {xformers_package}\", \"xformers\")\r\n        startup_timer.record(\"install xformers\")\r\n\r\n    if not is_installed(\"ngrok\") and args.ngrok:\r\n        run_pip(\"install ngrok\", \"ngrok\")\r\n        startup_timer.record(\"install ngrok\")\r\n\r\n    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)\r\n\r\n    git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), \"assets\", assets_commit_hash)\r\n    git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), \"Stable Diffusion\", stable_diffusion_commit_hash)\r\n    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), \"Stable Diffusion XL\", stable_diffusion_xl_commit_hash)\r\n    git_clone(k_diffusion_repo, repo_dir('k-diffusion'), \"K-diffusion\", k_diffusion_commit_hash)\r\n    git_clone(blip_repo, repo_dir('BLIP'), \"BLIP\", blip_commit_hash)\r\n\r\n    startup_timer.record(\"clone repositores\")\r\n\r\n    if not os.path.isfile(requirements_file):\r\n 
       requirements_file = os.path.join(script_path, requirements_file)\r\n\r\n    if not requirements_met(requirements_file):\r\n        run_pip(f\"install -r \\\"{requirements_file}\\\"\", \"requirements\")\r\n        startup_timer.record(\"install requirements\")\r\n\r\n    if not os.path.isfile(requirements_file_for_npu):\r\n        requirements_file_for_npu = os.path.join(script_path, requirements_file_for_npu)\r\n\r\n    if \"torch_npu\" in torch_command and not requirements_met(requirements_file_for_npu):\r\n        run_pip(f\"install -r \\\"{requirements_file_for_npu}\\\"\", \"requirements_for_npu\")\r\n        startup_timer.record(\"install requirements_for_npu\")\r\n\r\n    if not args.skip_install:\r\n        run_extensions_installers(settings_file=args.ui_settings_file)\r\n\r\n    if args.update_check:\r\n        version_check(commit)\r\n        startup_timer.record(\"check version\")\r\n\r\n    if args.update_all_extensions:\r\n        git_pull_recursive(extensions_dir)\r\n        startup_timer.record(\"update extensions\")\r\n\r\n    if \"--exit\" in sys.argv:\r\n        print(\"Exiting because of --exit argument\")\r\n        exit(0)\r\n\r\n\r\ndef configure_for_tests():\r\n    if \"--api\" not in sys.argv:\r\n        sys.argv.append(\"--api\")\r\n    if \"--ckpt\" not in sys.argv:\r\n        sys.argv.append(\"--ckpt\")\r\n        sys.argv.append(os.path.join(script_path, \"test/test_files/empty.pt\"))\r\n    if \"--skip-torch-cuda-test\" not in sys.argv:\r\n        sys.argv.append(\"--skip-torch-cuda-test\")\r\n    if \"--disable-nan-check\" not in sys.argv:\r\n        sys.argv.append(\"--disable-nan-check\")\r\n\r\n    os.environ['COMMANDLINE_ARGS'] = \"\"\r\n\r\n\r\ndef start():\r\n    print(f\"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}\")\r\n    import webui\r\n    if '--nowebui' in sys.argv:\r\n        webui.api_only()\r\n    else:\r\n        webui.webui()\r\n\r\n\r\ndef 
dump_sysinfo():\r\n    from modules import sysinfo\r\n    import datetime\r\n\r\n    text = sysinfo.get()\r\n    filename = f\"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json\"\r\n\r\n    with open(filename, \"w\", encoding=\"utf8\") as file:\r\n        file.write(text)\r\n\r\n    return filename\r\n"
  },
  {
    "path": "modules/localization.py",
    "content": "import json\r\nimport os\r\n\r\nfrom modules import errors, scripts\r\n\r\nlocalizations = {}\r\n\r\n\r\ndef list_localizations(dirname):\r\n    localizations.clear()\r\n\r\n    for file in os.listdir(dirname):\r\n        fn, ext = os.path.splitext(file)\r\n        if ext.lower() != \".json\":\r\n            continue\r\n\r\n        localizations[fn] = [os.path.join(dirname, file)]\r\n\r\n    for file in scripts.list_scripts(\"localizations\", \".json\"):\r\n        fn, ext = os.path.splitext(file.filename)\r\n        if fn not in localizations:\r\n            localizations[fn] = []\r\n        localizations[fn].append(file.path)\r\n\r\n\r\ndef localization_js(current_localization_name: str) -> str:\r\n    fns = localizations.get(current_localization_name, None)\r\n    data = {}\r\n    if fns is not None:\r\n        for fn in fns:\r\n            try:\r\n                with open(fn, \"r\", encoding=\"utf8\") as file:\r\n                    data.update(json.load(file))\r\n            except Exception:\r\n                errors.report(f\"Error loading localization from {fn}\", exc_info=True)\r\n\r\n    return f\"window.localization = {json.dumps(data)}\"\r\n"
  },
  {
    "path": "modules/logging_config.py",
    "content": "import logging\r\nimport os\r\n\r\ntry:\r\n    from tqdm import tqdm\r\n\r\n\r\n    class TqdmLoggingHandler(logging.Handler):\r\n        def __init__(self, fallback_handler: logging.Handler):\r\n            super().__init__()\r\n            self.fallback_handler = fallback_handler\r\n\r\n        def emit(self, record):\r\n            try:\r\n                # If there are active tqdm progress bars,\r\n                # attempt to not interfere with them.\r\n                if tqdm._instances:\r\n                    tqdm.write(self.format(record))\r\n                else:\r\n                    self.fallback_handler.emit(record)\r\n            except Exception:\r\n                self.fallback_handler.emit(record)\r\n\r\nexcept ImportError:\r\n    TqdmLoggingHandler = None\r\n\r\n\r\ndef setup_logging(loglevel):\r\n    if loglevel is None:\r\n        loglevel = os.environ.get(\"SD_WEBUI_LOG_LEVEL\")\r\n\r\n    if not loglevel:\r\n        return\r\n\r\n    if logging.root.handlers:\r\n        # Already configured, do not interfere\r\n        return\r\n\r\n    formatter = logging.Formatter(\r\n        '%(asctime)s %(levelname)s [%(name)s] %(message)s',\r\n        '%Y-%m-%d %H:%M:%S',\r\n    )\r\n\r\n    if os.environ.get(\"SD_WEBUI_RICH_LOG\"):\r\n        from rich.logging import RichHandler\r\n        handler = RichHandler()\r\n    else:\r\n        handler = logging.StreamHandler()\r\n        handler.setFormatter(formatter)\r\n\r\n    if TqdmLoggingHandler:\r\n        handler = TqdmLoggingHandler(handler)\r\n\r\n    handler.setFormatter(formatter)\r\n\r\n    log_level = getattr(logging, loglevel.upper(), None) or logging.INFO\r\n    logging.root.setLevel(log_level)\r\n    logging.root.addHandler(handler)\r\n"
  },
  {
    "path": "modules/lowvram.py",
    "content": "from collections import namedtuple\r\n\r\nimport torch\r\nfrom modules import devices, shared\r\n\r\nmodule_in_gpu = None\r\ncpu = torch.device(\"cpu\")\r\n\r\nModuleWithParent = namedtuple('ModuleWithParent', ['module', 'parent'], defaults=['None'])\r\n\r\ndef send_everything_to_cpu():\r\n    global module_in_gpu\r\n\r\n    if module_in_gpu is not None:\r\n        module_in_gpu.to(cpu)\r\n\r\n    module_in_gpu = None\r\n\r\n\r\ndef is_needed(sd_model):\r\n    return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')\r\n\r\n\r\ndef apply(sd_model):\r\n    enable = is_needed(sd_model)\r\n    shared.parallel_processing_allowed = not enable\r\n\r\n    if enable:\r\n        setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)\r\n    else:\r\n        sd_model.lowvram = False\r\n\r\n\r\ndef setup_for_low_vram(sd_model, use_medvram):\r\n    if getattr(sd_model, 'lowvram', False):\r\n        return\r\n\r\n    sd_model.lowvram = True\r\n\r\n    parents = {}\r\n\r\n    def send_me_to_gpu(module, _):\r\n        \"\"\"send this module to GPU; send whatever tracked module was previous in GPU to CPU;\r\n        we add this as forward_pre_hook to a lot of modules and this way all but one of them will\r\n        be in CPU\r\n        \"\"\"\r\n        global module_in_gpu\r\n\r\n        module = parents.get(module, module)\r\n\r\n        if module_in_gpu == module:\r\n            return\r\n\r\n        if module_in_gpu is not None:\r\n            module_in_gpu.to(cpu)\r\n\r\n        module.to(devices.device)\r\n        module_in_gpu = module\r\n\r\n    # see below for register_forward_pre_hook;\r\n    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is\r\n    # useless here, and we just replace those methods\r\n\r\n    first_stage_model = sd_model.first_stage_model\r\n    first_stage_model_encode = sd_model.first_stage_model.encode\r\n    
first_stage_model_decode = sd_model.first_stage_model.decode\r\n\r\n    def first_stage_model_encode_wrap(x):\r\n        send_me_to_gpu(first_stage_model, None)\r\n        return first_stage_model_encode(x)\r\n\r\n    def first_stage_model_decode_wrap(z):\r\n        send_me_to_gpu(first_stage_model, None)\r\n        return first_stage_model_decode(z)\r\n\r\n    to_remain_in_cpu = [\r\n        (sd_model, 'first_stage_model'),\r\n        (sd_model, 'depth_model'),\r\n        (sd_model, 'embedder'),\r\n        (sd_model, 'model'),\r\n    ]\r\n\r\n    is_sdxl = hasattr(sd_model, 'conditioner')\r\n    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')\r\n\r\n    if hasattr(sd_model, 'medvram_fields'):\r\n        to_remain_in_cpu = sd_model.medvram_fields()\r\n    elif is_sdxl:\r\n        to_remain_in_cpu.append((sd_model, 'conditioner'))\r\n    elif is_sd2:\r\n        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))\r\n    else:\r\n        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))\r\n\r\n    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model\r\n    stored = []\r\n    for obj, field in to_remain_in_cpu:\r\n        module = getattr(obj, field, None)\r\n        stored.append(module)\r\n        setattr(obj, field, None)\r\n\r\n    # send the model to GPU.\r\n    sd_model.to(devices.device)\r\n\r\n    # put modules back. 
the modules will be in CPU.\r\n    for (obj, field), module in zip(to_remain_in_cpu, stored):\r\n        setattr(obj, field, module)\r\n\r\n    # register hooks for those the first three models\r\n    if hasattr(sd_model, \"cond_stage_model\") and hasattr(sd_model.cond_stage_model, \"medvram_modules\"):\r\n        for module in sd_model.cond_stage_model.medvram_modules():\r\n            if isinstance(module, ModuleWithParent):\r\n                parent = module.parent\r\n                module = module.module\r\n            else:\r\n                parent = None\r\n\r\n            if module:\r\n                module.register_forward_pre_hook(send_me_to_gpu)\r\n\r\n                if parent:\r\n                    parents[module] = parent\r\n\r\n    elif is_sdxl:\r\n        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)\r\n    elif is_sd2:\r\n        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)\r\n        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)\r\n        parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model\r\n        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model\r\n    else:\r\n        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)\r\n        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model\r\n\r\n    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)\r\n    sd_model.first_stage_model.encode = first_stage_model_encode_wrap\r\n    sd_model.first_stage_model.decode = first_stage_model_decode_wrap\r\n    if getattr(sd_model, 'depth_model', None) is not None:\r\n        sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)\r\n    if getattr(sd_model, 'embedder', None) is not None:\r\n        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)\r\n\r\n    if use_medvram:\r\n        
sd_model.model.register_forward_pre_hook(send_me_to_gpu)\r\n    else:\r\n        diff_model = sd_model.model.diffusion_model\r\n\r\n        # the third remaining model is still too big for 4 GB, so we also do the same for its submodules\r\n        # so that only one of them is in GPU at a time\r\n        stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed\r\n        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None\r\n        sd_model.model.to(devices.device)\r\n        diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored\r\n\r\n        # install hooks for bits of third model\r\n        diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)\r\n        for block in diff_model.input_blocks:\r\n            block.register_forward_pre_hook(send_me_to_gpu)\r\n        diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)\r\n        for block in diff_model.output_blocks:\r\n            block.register_forward_pre_hook(send_me_to_gpu)\r\n\r\n\r\ndef is_enabled(sd_model):\r\n    return sd_model.lowvram\r\n"
  },
  {
    "path": "modules/mac_specific.py",
    "content": "import logging\n\nimport torch\nfrom torch import Tensor\nimport platform\nfrom modules.sd_hijack_utils import CondFunc\nfrom packaging import version\nfrom modules import shared\n\nlog = logging.getLogger(__name__)\n\n\n# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,\n# use check `getattr` and try it for compatibility.\n# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availability,\n# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279\ndef check_for_mps() -> bool:\n    if version.parse(torch.__version__) <= version.parse(\"2.0.1\"):\n        if not getattr(torch, 'has_mps', False):\n            return False\n        try:\n            torch.zeros(1).to(torch.device(\"mps\"))\n            return True\n        except Exception:\n            return False\n    else:\n        return torch.backends.mps.is_available() and torch.backends.mps.is_built()\n\n\nhas_mps = check_for_mps()\n\n\ndef torch_mps_gc() -> None:\n    try:\n        if shared.state.current_latent is not None:\n            log.debug(\"`current_latent` is set, skipping MPS garbage collection\")\n            return\n        from torch.mps import empty_cache\n        empty_cache()\n    except Exception:\n        log.warning(\"MPS garbage collection failed\", exc_info=True)\n\n\n# MPS workaround for https://github.com/pytorch/pytorch/issues/89784\ndef cumsum_fix(input, cumsum_func, *args, **kwargs):\n    if input.device.type == 'mps':\n        output_dtype = kwargs.get('dtype', input.dtype)\n        if output_dtype == torch.int64:\n            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)\n        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):\n            return cumsum_func(input.to(torch.int32), *args, 
**kwargs).to(torch.int64)\n    return cumsum_func(input, *args, **kwargs)\n\n\n# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046\ndef interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:\n    try:\n        return orig_func(*args, **kwargs)\n    except RuntimeError as e:\n        if \"not implemented for\" in str(e) and \"Half\" in str(e):\n            input_tensor = args[0]\n            return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)\n        else:\n            print(f\"An unexpected RuntimeError occurred: {str(e)}\")\n\nif has_mps:\n    if platform.mac_ver()[0].startswith(\"13.2.\"):\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)\n        CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)\n\n    if version.parse(torch.__version__) < version.parse(\"1.13\"):\n        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working\n\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383\n        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),\n                                                          lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800\n        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), 
**kwargs),\n                                                                                        lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532\n        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)\n    elif version.parse(torch.__version__) > version.parse(\"1.13.1\"):\n        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device(\"mps\")).equal(torch.ShortTensor([1,1]).to(torch.device(\"mps\")).cumsum(0))\n        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)\n        CondFunc('torch.cumsum', cumsum_fix_func, None)\n        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)\n        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)\n\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/96113\n        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')\n\n        # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046\n        CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)\n\n        # MPS workaround for https://github.com/pytorch/pytorch/issues/92311\n        if platform.processor() == 'i386':\n            for funcName in ['torch.argmax', 'torch.Tensor.argmax']:\n                CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], 
lambda _, input, *args, **kwargs: input.device.type == 'mps')\n"
  },
  {
    "path": "modules/masking.py",
    "content": "from PIL import Image, ImageFilter, ImageOps\r\n\r\n\r\ndef get_crop_region_v2(mask, pad=0):\r\n    \"\"\"\r\n    Finds a rectangular region that contains all masked ares in a mask.\r\n    Returns None if mask is completely black mask (all 0)\r\n\r\n    Parameters:\r\n    mask: PIL.Image.Image L mode or numpy 1d array\r\n    pad: int number of pixels that the region will be extended on all sides\r\n    Returns: (x1, y1, x2, y2) | None\r\n\r\n    Introduced post 1.9.0\r\n    \"\"\"\r\n    mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)\r\n    if box := mask.getbbox():\r\n        x1, y1, x2, y2 = box\r\n        return (max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])) if pad else box\r\n\r\n\r\ndef get_crop_region(mask, pad=0):\r\n    \"\"\"\r\n    Same function as get_crop_region_v2 but handles completely black mask (all 0) differently\r\n    when mask all black still return coordinates but the coordinates may be invalid ie x2>x1 or y2>y1\r\n    Notes: it is possible for the coordinates to be \"valid\" again if pad size is sufficiently large\r\n    (mask_size.x-pad, mask_size.y-pad, pad, pad)\r\n\r\n    Extension developer should use get_crop_region_v2 instead unless for compatibility considerations.\r\n    \"\"\"\r\n    mask = mask if isinstance(mask, Image.Image) else Image.fromarray(mask)\r\n    if box := get_crop_region_v2(mask, pad):\r\n        return box\r\n    x1, y1 = mask.size\r\n    x2 = y2 = 0\r\n    return max(x1 - pad, 0), max(y1 - pad, 0), min(x2 + pad, mask.size[0]), min(y2 + pad, mask.size[1])\r\n\r\n\r\ndef expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):\r\n    \"\"\"expands crop region get_crop_region() to match the ratio of the image the region will processed in; returns expanded region\r\n    for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded 
to 128x128.\"\"\"\r\n\r\n    x1, y1, x2, y2 = crop_region\r\n\r\n    ratio_crop_region = (x2 - x1) / (y2 - y1)\r\n    ratio_processing = processing_width / processing_height\r\n\r\n    if ratio_crop_region > ratio_processing:\r\n        desired_height = (x2 - x1) / ratio_processing\r\n        desired_height_diff = int(desired_height - (y2-y1))\r\n        y1 -= desired_height_diff//2\r\n        y2 += desired_height_diff - desired_height_diff//2\r\n        if y2 >= image_height:\r\n            diff = y2 - image_height\r\n            y2 -= diff\r\n            y1 -= diff\r\n        if y1 < 0:\r\n            y2 -= y1\r\n            y1 -= y1\r\n        if y2 >= image_height:\r\n            y2 = image_height\r\n    else:\r\n        desired_width = (y2 - y1) * ratio_processing\r\n        desired_width_diff = int(desired_width - (x2-x1))\r\n        x1 -= desired_width_diff//2\r\n        x2 += desired_width_diff - desired_width_diff//2\r\n        if x2 >= image_width:\r\n            diff = x2 - image_width\r\n            x2 -= diff\r\n            x1 -= diff\r\n        if x1 < 0:\r\n            x2 -= x1\r\n            x1 -= x1\r\n        if x2 >= image_width:\r\n            x2 = image_width\r\n\r\n    return x1, y1, x2, y2\r\n\r\n\r\ndef fill(image, mask):\r\n    \"\"\"fills masked regions with colors from image using blur. Not extremely effective.\"\"\"\r\n\r\n    image_mod = Image.new('RGBA', (image.width, image.height))\r\n\r\n    image_masked = Image.new('RGBa', (image.width, image.height))\r\n    image_masked.paste(image.convert(\"RGBA\").convert(\"RGBa\"), mask=ImageOps.invert(mask.convert('L')))\r\n\r\n    image_masked = image_masked.convert('RGBa')\r\n\r\n    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:\r\n        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')\r\n        for _ in range(repeats):\r\n            image_mod.alpha_composite(blurred)\r\n\r\n    return image_mod.convert(\"RGB\")\r\n\r\n"
  },
  {
    "path": "modules/memmon.py",
    "content": "import threading\nimport time\nfrom collections import defaultdict\n\nimport torch\n\n\nclass MemUsageMonitor(threading.Thread):\n    run_flag = None\n    device = None\n    disabled = False\n    opts = None\n    data = None\n\n    def __init__(self, name, device, opts):\n        threading.Thread.__init__(self)\n        self.name = name\n        self.device = device\n        self.opts = opts\n\n        self.daemon = True\n        self.run_flag = threading.Event()\n        self.data = defaultdict(int)\n\n        try:\n            self.cuda_mem_get_info()\n            torch.cuda.memory_stats(self.device)\n        except Exception as e:  # AMD or whatever\n            print(f\"Warning: caught exception '{e}', memory monitor disabled\")\n            self.disabled = True\n\n    def cuda_mem_get_info(self):\n        index = self.device.index if self.device.index is not None else torch.cuda.current_device()\n        return torch.cuda.mem_get_info(index)\n\n    def run(self):\n        if self.disabled:\n            return\n\n        while True:\n            self.run_flag.wait()\n\n            torch.cuda.reset_peak_memory_stats()\n            self.data.clear()\n\n            if self.opts.memmon_poll_rate <= 0:\n                self.run_flag.clear()\n                continue\n\n            self.data[\"min_free\"] = self.cuda_mem_get_info()[0]\n\n            while self.run_flag.is_set():\n                free, total = self.cuda_mem_get_info()\n                self.data[\"min_free\"] = min(self.data[\"min_free\"], free)\n\n                time.sleep(1 / self.opts.memmon_poll_rate)\n\n    def dump_debug(self):\n        print(self, 'recorded data:')\n        for k, v in self.read().items():\n            print(k, -(v // -(1024 ** 2)))\n\n        print(self, 'raw torch memory stats:')\n        tm = torch.cuda.memory_stats(self.device)\n        for k, v in tm.items():\n            if 'bytes' not in k:\n                continue\n            print('\\t' if 'peak' in k 
else '', k, -(v // -(1024 ** 2)))\n\n        print(torch.cuda.memory_summary())\n\n    def monitor(self):\n        self.run_flag.set()\n\n    def read(self):\n        if not self.disabled:\n            free, total = self.cuda_mem_get_info()\n            self.data[\"free\"] = free\n            self.data[\"total\"] = total\n\n            torch_stats = torch.cuda.memory_stats(self.device)\n            self.data[\"active\"] = torch_stats[\"active.all.current\"]\n            self.data[\"active_peak\"] = torch_stats[\"active_bytes.all.peak\"]\n            self.data[\"reserved\"] = torch_stats[\"reserved_bytes.all.current\"]\n            self.data[\"reserved_peak\"] = torch_stats[\"reserved_bytes.all.peak\"]\n            self.data[\"system_peak\"] = total - self.data[\"min_free\"]\n\n        return self.data\n\n    def stop(self):\n        self.run_flag.clear()\n        return self.read()\n"
  },
  {
    "path": "modules/modelloader.py",
    "content": "from __future__ import annotations\n\nimport importlib\nimport logging\nimport os\nfrom typing import TYPE_CHECKING\nfrom urllib.parse import urlparse\n\nimport torch\n\nfrom modules import shared\nfrom modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone\n\nif TYPE_CHECKING:\n    import spandrel\n\nlogger = logging.getLogger(__name__)\n\n\ndef load_file_from_url(\n    url: str,\n    *,\n    model_dir: str,\n    progress: bool = True,\n    file_name: str | None = None,\n    hash_prefix: str | None = None,\n) -> str:\n    \"\"\"Download a file from `url` into `model_dir`, using the file present if possible.\n\n    Returns the path to the downloaded file.\n    \"\"\"\n    os.makedirs(model_dir, exist_ok=True)\n    if not file_name:\n        parts = urlparse(url)\n        file_name = os.path.basename(parts.path)\n    cached_file = os.path.abspath(os.path.join(model_dir, file_name))\n    if not os.path.exists(cached_file):\n        print(f'Downloading: \"{url}\" to {cached_file}\\n')\n        from torch.hub import download_url_to_file\n        download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)\n    return cached_file\n\n\ndef load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:\n    \"\"\"\n    A one-and done loader to try finding the desired models in specified directories.\n\n    @param download_name: Specify to download from model_url immediately.\n    @param model_url: If no other models are found, this will be downloaded on upscale.\n    @param model_path: The location to store/find models in.\n    @param command_path: A command-line argument to search for models in first.\n    @param ext_filter: An optional list of filename extensions to filter by\n    @param hash_prefix: the expected sha256 of the model_url\n    @return: A list of paths containing the desired model(s)\n    
\"\"\"\n    output = []\n\n    try:\n        places = []\n\n        if command_path is not None and command_path != model_path:\n            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')\n            if os.path.exists(pretrained_path):\n                print(f\"Appending path: {pretrained_path}\")\n                places.append(pretrained_path)\n            elif os.path.exists(command_path):\n                places.append(command_path)\n\n        places.append(model_path)\n\n        for place in places:\n            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):\n                if os.path.islink(full_path) and not os.path.exists(full_path):\n                    print(f\"Skipping broken symlink: {full_path}\")\n                    continue\n                if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):\n                    continue\n                if full_path not in output:\n                    output.append(full_path)\n\n        if model_url is not None and len(output) == 0:\n            if download_name is not None:\n                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name, hash_prefix=hash_prefix))\n            else:\n                output.append(model_url)\n\n    except Exception:\n        pass\n\n    return output\n\n\ndef friendly_name(file: str):\n    if file.startswith(\"http\"):\n        file = urlparse(file).path\n\n    file = os.path.basename(file)\n    model_name, extension = os.path.splitext(file)\n    return model_name\n\n\ndef load_upscalers():\n    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,\n    # so we'll try to import any _model.py files before looking in __subclasses__\n    modules_dir = os.path.join(shared.script_path, \"modules\")\n    for file in os.listdir(modules_dir):\n        if \"_model.py\" in file:\n            model_name = 
file.replace(\"_model.py\", \"\")\n            full_model = f\"modules.{model_name}_model\"\n            try:\n                importlib.import_module(full_model)\n            except Exception:\n                pass\n\n    data = []\n    commandline_options = vars(shared.cmd_opts)\n\n    # some of upscaler classes will not go away after reloading their modules, and we'll end\n    # up with two copies of those classes. The newest copy will always be the last in the list,\n    # so we go from end to beginning and ignore duplicates\n    used_classes = {}\n    for cls in reversed(Upscaler.__subclasses__()):\n        classname = str(cls)\n        if classname not in used_classes:\n            used_classes[classname] = cls\n\n    for cls in reversed(used_classes.values()):\n        name = cls.__name__\n        cmd_name = f\"{name.lower().replace('upscaler', '')}_models_path\"\n        commandline_model_path = commandline_options.get(cmd_name, None)\n        scaler = cls(commandline_model_path)\n        scaler.user_path = commandline_model_path\n        scaler.model_download_path = commandline_model_path or scaler.model_path\n        data += scaler.scalers\n\n    shared.sd_upscalers = sorted(\n        data,\n        # Special case for UpscalerNone keeps it at the beginning of the list.\n        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else \"\"\n    )\n\n# None: not loaded, False: failed to load, True: loaded\n_spandrel_extra_init_state = None\n\n\ndef _init_spandrel_extra_archs() -> None:\n    \"\"\"\n    Try to initialize `spandrel_extra_archs` (exactly once).\n    \"\"\"\n    global _spandrel_extra_init_state\n    if _spandrel_extra_init_state is not None:\n        return\n\n    try:\n        import spandrel\n        import spandrel_extra_arches\n        spandrel.MAIN_REGISTRY.add(*spandrel_extra_arches.EXTRA_REGISTRY)\n        _spandrel_extra_init_state = True\n    except Exception:\n        
logger.warning(\"Failed to load spandrel_extra_arches\", exc_info=True)\n        _spandrel_extra_init_state = False\n\n\ndef load_spandrel_model(\n    path: str | os.PathLike,\n    *,\n    device: str | torch.device | None,\n    prefer_half: bool = False,\n    dtype: str | torch.dtype | None = None,\n    expected_architecture: str | None = None,\n) -> spandrel.ModelDescriptor:\n    global _spandrel_extra_init_state\n\n    import spandrel\n    _init_spandrel_extra_archs()\n\n    model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path))\n    arch = model_descriptor.architecture\n    if expected_architecture and arch.name != expected_architecture:\n        logger.warning(\n            f\"Model {path!r} is not a {expected_architecture!r} model (got {arch.name!r})\",\n        )\n    half = False\n    if prefer_half:\n        if model_descriptor.supports_half:\n            model_descriptor.model.half()\n            half = True\n        else:\n            logger.info(\"Model %s does not support half precision, ignoring --half\", path)\n    if dtype:\n        model_descriptor.model.to(dtype=dtype)\n    model_descriptor.model.eval()\n    logger.debug(\n        \"Loaded %s from %s (device=%s, half=%s, dtype=%s)\",\n        arch, path, device, half, dtype,\n    )\n    return model_descriptor\n"
  },
  {
    "path": "modules/models/diffusion/ddpm_edit.py",
    "content": "\"\"\"\nwild mixture of\nhttps://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\nhttps://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py\nhttps://github.com/CompVis/taming-transformers\n-- merci\n\"\"\"\n\n# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).\n# See more details in LICENSE.\n\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport pytorch_lightning as pl\nfrom torch.optim.lr_scheduler import LambdaLR\nfrom einops import rearrange, repeat\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom tqdm import tqdm\nfrom torchvision.utils import make_grid\nfrom pytorch_lightning.utilities.distributed import rank_zero_only\n\nfrom ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config\nfrom ldm.modules.ema import LitEma\nfrom ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution\nfrom ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL\nfrom ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like\nfrom ldm.models.diffusion.ddim import DDIMSampler\n\ntry:\n    from ldm.models.autoencoder import VQModelInterface\nexcept Exception:\n    class VQModelInterface:\n        pass\n\n__conditioning_keys__ = {'concat': 'c_concat',\n                         'crossattn': 'c_crossattn',\n                         'adm': 'y'}\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef uniform_on_device(r1, r2, shape, device):\n    return (r1 - r2) * torch.rand(*shape, device=device) + r2\n\n\nclass DDPM(pl.LightningModule):\n    # 
classic DDPM with Gaussian diffusion, in image space\n    def __init__(self,\n                 unet_config,\n                 timesteps=1000,\n                 beta_schedule=\"linear\",\n                 loss_type=\"l2\",\n                 ckpt_path=None,\n                 ignore_keys=None,\n                 load_only_unet=False,\n                 monitor=\"val/loss\",\n                 use_ema=True,\n                 first_stage_key=\"image\",\n                 image_size=256,\n                 channels=3,\n                 log_every_t=100,\n                 clip_denoised=True,\n                 linear_start=1e-4,\n                 linear_end=2e-2,\n                 cosine_s=8e-3,\n                 given_betas=None,\n                 original_elbo_weight=0.,\n                 v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta\n                 l_simple_weight=1.,\n                 conditioning_key=None,\n                 parameterization=\"eps\",  # all assuming fixed variance schedules\n                 scheduler_config=None,\n                 use_positional_encodings=False,\n                 learn_logvar=False,\n                 logvar_init=0.,\n                 load_ema=True,\n                 ):\n        super().__init__()\n        assert parameterization in [\"eps\", \"x0\"], 'currently only supporting \"eps\" and \"x0\"'\n        self.parameterization = parameterization\n        print(f\"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode\")\n        self.cond_stage_model = None\n        self.clip_denoised = clip_denoised\n        self.log_every_t = log_every_t\n        self.first_stage_key = first_stage_key\n        self.image_size = image_size  # try conv?\n        self.channels = channels\n        self.use_positional_encodings = use_positional_encodings\n        self.model = DiffusionWrapper(unet_config, conditioning_key)\n        count_params(self.model, verbose=True)\n        
self.use_ema = use_ema\n\n        self.use_scheduler = scheduler_config is not None\n        if self.use_scheduler:\n            self.scheduler_config = scheduler_config\n\n        self.v_posterior = v_posterior\n        self.original_elbo_weight = original_elbo_weight\n        self.l_simple_weight = l_simple_weight\n\n        if monitor is not None:\n            self.monitor = monitor\n\n        if self.use_ema and load_ema:\n            self.model_ema = LitEma(self.model)\n            print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)\n\n            # If initialing from EMA-only checkpoint, create EMA model after loading.\n            if self.use_ema and not load_ema:\n                self.model_ema = LitEma(self.model)\n                print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,\n                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)\n\n        self.loss_type = loss_type\n\n        self.learn_logvar = learn_logvar\n        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))\n        if self.learn_logvar:\n            self.logvar = nn.Parameter(self.logvar, requires_grad=True)\n\n\n    def register_schedule(self, given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        if exists(given_betas):\n            betas = given_betas\n        else:\n            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,\n                                       cosine_s=cosine_s)\n        alphas = 1. 
- betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])\n\n        timesteps, = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer('betas', to_torch(betas))\n        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))\n        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))\n        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))\n        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (\n                    1. - alphas_cumprod) + self.v_posterior * betas\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer('posterior_variance', to_torch(posterior_variance))\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))\n        self.register_buffer('posterior_mean_coef1', to_torch(\n            betas * np.sqrt(alphas_cumprod_prev) / (1. 
- alphas_cumprod)))\n        self.register_buffer('posterior_mean_coef2', to_torch(\n            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))\n\n        if self.parameterization == \"eps\":\n            lvlb_weights = self.betas ** 2 / (\n                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))\n        elif self.parameterization == \"x0\":\n            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))\n        else:\n            raise NotImplementedError(\"mu not supported\")\n        # TODO how to choose this term\n        lvlb_weights[0] = lvlb_weights[1]\n        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)\n        assert not torch.isnan(self.lvlb_weights).all()\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            if context is not None:\n                print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                if context is not None:\n                    print(f\"{context}: Restored training weights\")\n\n    def init_from_ckpt(self, path, ignore_keys=None, only_model=False):\n        ignore_keys = ignore_keys or []\n\n        sd = torch.load(path, map_location=\"cpu\")\n        if \"state_dict\" in list(sd.keys()):\n            sd = sd[\"state_dict\"]\n        keys = list(sd.keys())\n\n        # Our model adds additional channels to the first layer to condition on an input image.\n        # For the first layer, copy existing channel weights and initialize new channel weights to zero.\n        input_keys = [\n            \"model.diffusion_model.input_blocks.0.0.weight\",\n            
\"model_ema.diffusion_modelinput_blocks00weight\",\n        ]\n\n        self_sd = self.state_dict()\n        for input_key in input_keys:\n            if input_key not in sd or input_key not in self_sd:\n                continue\n\n            input_weight = self_sd[input_key]\n\n            if input_weight.size() != sd[input_key].size():\n                print(f\"Manual init: {input_key}\")\n                input_weight.zero_()\n                input_weight[:, :4, :, :].copy_(sd[input_key])\n                ignore_keys.append(input_key)\n\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(f\"Deleting key {k} from state_dict.\")\n                    del sd[k]\n        missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(\n            sd, strict=False)\n        print(f\"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys\")\n        if missing:\n            print(f\"Missing Keys: {missing}\")\n        if unexpected:\n            print(f\"Unexpected Keys: {unexpected}\")\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). 
Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)\n        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n                extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -\n                extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n                extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +\n                extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, clip_denoised: bool):\n        model_out = self.model(x, t)\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == \"x0\":\n            x_recon = model_out\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = 
self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape, return_intermediates=False):\n        device = self.betas.device\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n        intermediates = [img]\n        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):\n            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),\n                                clip_denoised=self.clip_denoised)\n            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:\n                intermediates.append(img)\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, batch_size=16, return_intermediates=False):\n        image_size = self.image_size\n        channels = self.channels\n        return self.p_sample_loop((batch_size, channels, image_size, image_size),\n                                  return_intermediates=return_intermediates)\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +\n                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)\n\n    def get_loss(self, pred, target, mean=True):\n        if self.loss_type == 'l1':\n            loss = (target - pred).abs()\n            if mean:\n                loss = loss.mean()\n        elif self.loss_type == 'l2':\n            if mean:\n                loss = torch.nn.functional.mse_loss(target, pred)\n            else:\n  
              loss = torch.nn.functional.mse_loss(target, pred, reduction='none')\n        else:\n            raise NotImplementedError(\"unknown loss type '{loss_type}'\")\n\n        return loss\n\n    def p_losses(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_out = self.model(x_noisy, t)\n\n        loss_dict = {}\n        if self.parameterization == \"eps\":\n            target = noise\n        elif self.parameterization == \"x0\":\n            target = x_start\n        else:\n            raise NotImplementedError(f\"Parameterization {self.parameterization} not yet supported\")\n\n        loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])\n\n        log_prefix = 'train' if self.training else 'val'\n\n        loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})\n        loss_simple = loss.mean() * self.l_simple_weight\n\n        loss_vlb = (self.lvlb_weights[t] * loss).mean()\n        loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})\n\n        loss = loss_simple + self.original_elbo_weight * loss_vlb\n\n        loss_dict.update({f'{log_prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def forward(self, x, *args, **kwargs):\n        # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size\n        # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        return self.p_losses(x, t, *args, **kwargs)\n\n    def get_input(self, batch, k):\n        return batch[k]\n\n    def shared_step(self, batch):\n        x = self.get_input(batch, self.first_stage_key)\n        loss, loss_dict = self(x)\n        return loss, loss_dict\n\n    def training_step(self, batch, batch_idx):\n        loss, loss_dict = self.shared_step(batch)\n\n        
self.log_dict(loss_dict, prog_bar=True,\n                      logger=True, on_step=True, on_epoch=True)\n\n        self.log(\"global_step\", self.global_step,\n                 prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        if self.use_scheduler:\n            lr = self.optimizers().param_groups[0]['lr']\n            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)\n\n        return loss\n\n    @torch.no_grad()\n    def validation_step(self, batch, batch_idx):\n        _, loss_dict_no_ema = self.shared_step(batch)\n        with self.ema_scope():\n            _, loss_dict_ema = self.shared_step(batch)\n            loss_dict_ema = {f\"{key}_ema\": loss_dict_ema[key] for key in loss_dict_ema}\n        self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n        self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)\n\n    def on_train_batch_end(self, *args, **kwargs):\n        if self.use_ema:\n            self.model_ema(self.model)\n\n    def _get_rows_from_list(self, samples):\n        n_imgs_per_row = len(samples)\n        denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    @torch.no_grad()\n    def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):\n        log = {}\n        x = self.get_input(batch, self.first_stage_key)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        x = x.to(self.device)[:N]\n        log[\"inputs\"] = x\n\n        # get diffusion row\n        diffusion_row = []\n        x_start = x[:n_row]\n\n        for t in range(self.num_timesteps):\n            if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n 
               t = t.to(self.device).long()\n                noise = torch.randn_like(x_start)\n                x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n                diffusion_row.append(x_noisy)\n\n        log[\"diffusion_row\"] = self._get_rows_from_list(diffusion_row)\n\n        if sample:\n            # get denoise row\n            with self.ema_scope(\"Plotting\"):\n                samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)\n\n            log[\"samples\"] = samples\n            log[\"denoise_row\"] = self._get_rows_from_list(denoise_row)\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        if self.learn_logvar:\n            params = params + [self.logvar]\n        opt = torch.optim.AdamW(params, lr=lr)\n        return opt\n\n\nclass LatentDiffusion(DDPM):\n    \"\"\"main class\"\"\"\n    def __init__(self,\n                 first_stage_config,\n                 cond_stage_config,\n                 num_timesteps_cond=None,\n                 cond_stage_key=\"image\",\n                 cond_stage_trainable=False,\n                 concat_mode=True,\n                 cond_stage_forward=None,\n                 conditioning_key=None,\n                 scale_factor=1.0,\n                 scale_by_std=False,\n                 load_ema=True,\n                 *args, **kwargs):\n        self.num_timesteps_cond = default(num_timesteps_cond, 1)\n        self.scale_by_std = scale_by_std\n        assert self.num_timesteps_cond <= kwargs['timesteps']\n        # for backwards compatibility after implementation of DiffusionWrapper\n        if conditioning_key is None:\n            conditioning_key = 'concat' if 
concat_mode else 'crossattn'\n        if cond_stage_config == '__is_unconditional__':\n            conditioning_key = None\n        ckpt_path = kwargs.pop(\"ckpt_path\", None)\n        ignore_keys = kwargs.pop(\"ignore_keys\", [])\n        super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)\n        self.concat_mode = concat_mode\n        self.cond_stage_trainable = cond_stage_trainable\n        self.cond_stage_key = cond_stage_key\n        try:\n            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1\n        except Exception:\n            self.num_downs = 0\n        if not scale_by_std:\n            self.scale_factor = scale_factor\n        else:\n            self.register_buffer('scale_factor', torch.tensor(scale_factor))\n        self.instantiate_first_stage(first_stage_config)\n        self.instantiate_cond_stage(cond_stage_config)\n        self.cond_stage_forward = cond_stage_forward\n        self.clip_denoised = False\n        self.bbox_tokenizer = None\n\n        self.restarted_from_ckpt = False\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys)\n            self.restarted_from_ckpt = True\n\n            if self.use_ema and not load_ema:\n                self.model_ema = LitEma(self.model)\n                print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n    def make_cond_schedule(self, ):\n        self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)\n        ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()\n        self.cond_ids[:self.num_timesteps_cond] = ids\n\n    @rank_zero_only\n    @torch.no_grad()\n    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):\n        # only for very first batch\n        if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not 
self.restarted_from_ckpt:\n            assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'\n            # set rescale weight to 1./std of encodings\n            print(\"### USING STD-RESCALING ###\")\n            x = super().get_input(batch, self.first_stage_key)\n            x = x.to(self.device)\n            encoder_posterior = self.encode_first_stage(x)\n            z = self.get_first_stage_encoding(encoder_posterior).detach()\n            del self.scale_factor\n            self.register_buffer('scale_factor', 1. / z.flatten().std())\n            print(f\"setting self.scale_factor to {self.scale_factor}\")\n            print(\"### USING STD-RESCALING ###\")\n\n    def register_schedule(self,\n                          given_betas=None, beta_schedule=\"linear\", timesteps=1000,\n                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):\n        super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)\n\n        self.shorten_cond_schedule = self.num_timesteps_cond > 1\n        if self.shorten_cond_schedule:\n            self.make_cond_schedule()\n\n    def instantiate_first_stage(self, config):\n        model = instantiate_from_config(config)\n        self.first_stage_model = model.eval()\n        self.first_stage_model.train = disabled_train\n        for param in self.first_stage_model.parameters():\n            param.requires_grad = False\n\n    def instantiate_cond_stage(self, config):\n        if not self.cond_stage_trainable:\n            if config == \"__is_first_stage__\":\n                print(\"Using first stage also as cond stage.\")\n                self.cond_stage_model = self.first_stage_model\n            elif config == \"__is_unconditional__\":\n                print(f\"Training {self.__class__.__name__} as an unconditional model.\")\n                self.cond_stage_model = None\n                # self.be_unconditional = True\n            
else:\n                model = instantiate_from_config(config)\n                self.cond_stage_model = model.eval()\n                self.cond_stage_model.train = disabled_train\n                for param in self.cond_stage_model.parameters():\n                    param.requires_grad = False\n        else:\n            assert config != '__is_first_stage__'\n            assert config != '__is_unconditional__'\n            model = instantiate_from_config(config)\n            self.cond_stage_model = model\n\n    def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):\n        denoise_row = []\n        for zd in tqdm(samples, desc=desc):\n            denoise_row.append(self.decode_first_stage(zd.to(self.device),\n                                                            force_not_quantize=force_no_decoder_quantization))\n        n_imgs_per_row = len(denoise_row)\n        denoise_row = torch.stack(denoise_row)  # n_log_step, n_row, C, H, W\n        denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')\n        denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')\n        denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)\n        return denoise_grid\n\n    def get_first_stage_encoding(self, encoder_posterior):\n        if isinstance(encoder_posterior, DiagonalGaussianDistribution):\n            z = encoder_posterior.sample()\n        elif isinstance(encoder_posterior, torch.Tensor):\n            z = encoder_posterior\n        else:\n            raise NotImplementedError(f\"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented\")\n        return self.scale_factor * z\n\n    def get_learned_conditioning(self, c):\n        if self.cond_stage_forward is None:\n            if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):\n                c = self.cond_stage_model.encode(c)\n                if isinstance(c, DiagonalGaussianDistribution):\n        
            c = c.mode()\n            else:\n                c = self.cond_stage_model(c)\n        else:\n            assert hasattr(self.cond_stage_model, self.cond_stage_forward)\n            c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)\n        return c\n\n    def meshgrid(self, h, w):\n        y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)\n        x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)\n\n        arr = torch.cat([y, x], dim=-1)\n        return arr\n\n    def delta_border(self, h, w):\n        \"\"\"\n        :param h: height\n        :param w: width\n        :return: normalized distance to image border,\n         wtith min distance = 0 at border and max dist = 0.5 at image center\n        \"\"\"\n        lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)\n        arr = self.meshgrid(h, w) / lower_right_corner\n        dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]\n        dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]\n        edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]\n        return edge_dist\n\n    def get_weighting(self, h, w, Ly, Lx, device):\n        weighting = self.delta_border(h, w)\n        weighting = torch.clip(weighting, self.split_input_params[\"clip_min_weight\"],\n                               self.split_input_params[\"clip_max_weight\"], )\n        weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)\n\n        if self.split_input_params[\"tie_braker\"]:\n            L_weighting = self.delta_border(Ly, Lx)\n            L_weighting = torch.clip(L_weighting,\n                                     self.split_input_params[\"clip_min_tie_weight\"],\n                                     self.split_input_params[\"clip_max_tie_weight\"])\n\n            L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)\n            weighting = weighting * L_weighting\n        return weighting\n\n    def get_fold_unfold(self, 
x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code\n        \"\"\"\n        :param x: img of size (bs, c, h, w)\n        :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])\n        \"\"\"\n        bs, nc, h, w = x.shape\n\n        # number of crops in image\n        Ly = (h - kernel_size[0]) // stride[0] + 1\n        Lx = (w - kernel_size[1]) // stride[1] + 1\n\n        if uf == 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)\n\n            weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))\n\n        elif uf > 1 and df == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),\n                                dilation=1, padding=0,\n                                stride=(stride[0] * uf, stride[1] * uf))\n            fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))\n\n        elif df > 1 and uf == 1:\n            fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)\n            unfold = torch.nn.Unfold(**fold_params)\n\n            fold_params2 = 
dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),\n                                dilation=1, padding=0,\n                                stride=(stride[0] // df, stride[1] // df))\n            fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)\n\n            weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)\n            normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap\n            weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))\n\n        else:\n            raise NotImplementedError\n\n        return fold, unfold, normalization, weighting\n\n    @torch.no_grad()\n    def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,\n                  cond_key=None, return_original_cond=False, bs=None, uncond=0.05):\n        x = super().get_input(batch, k)\n        if bs is not None:\n            x = x[:bs]\n        x = x.to(self.device)\n        encoder_posterior = self.encode_first_stage(x)\n        z = self.get_first_stage_encoding(encoder_posterior).detach()\n        cond_key = cond_key or self.cond_stage_key\n        xc = super().get_input(batch, cond_key)\n        if bs is not None:\n            xc[\"c_crossattn\"] = xc[\"c_crossattn\"][:bs]\n            xc[\"c_concat\"] = xc[\"c_concat\"][:bs]\n        cond = {}\n\n        # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%.\n        random = torch.rand(x.size(0), device=x.device)\n        prompt_mask = rearrange(random < 2 * uncond, \"n -> n 1 1\")\n        input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), \"n -> n 1 1 1\")\n\n        null_prompt = self.get_learned_conditioning([\"\"])\n        cond[\"c_crossattn\"] = [torch.where(prompt_mask, null_prompt, 
self.get_learned_conditioning(xc[\"c_crossattn\"]).detach())]\n        cond[\"c_concat\"] = [input_mask * self.encode_first_stage((xc[\"c_concat\"].to(self.device))).mode().detach()]\n\n        out = [z, cond]\n        if return_first_stage_outputs:\n            xrec = self.decode_first_stage(z)\n            out.extend([x, xrec])\n        if return_original_cond:\n            out.append(xc)\n        return out\n\n    @torch.no_grad()\n    def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1. / self.scale_factor * z\n\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n                uf = self.split_input_params[\"vqf\"]\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. 
apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],\n                                                                 force_not_quantize=predict_cids or force_not_quantize)\n                                   for i in range(z.shape[-1])]\n                else:\n\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n                                   for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n            else:\n                return self.first_stage_model.decode(z)\n\n    # same as above but without decorator\n    def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):\n        if predict_cids:\n            if z.dim() == 4:\n                z = torch.argmax(z.exp(), dim=1).long()\n            z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)\n            z = rearrange(z, 'b h w c -> b c h w').contiguous()\n\n        z = 1. 
/ self.scale_factor * z\n\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n                uf = self.split_input_params[\"vqf\"]\n                bs, nc, h, w = z.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)\n\n                z = unfold(z)  # (bn, nc * prod(**ks), L)\n                # 1. Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                # 2. apply model loop over last dim\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i],\n                                                                 force_not_quantize=predict_cids or force_not_quantize)\n                                   for i in range(z.shape[-1])]\n                else:\n\n                    output_list = [self.first_stage_model.decode(z[:, :, :, :, i])\n                                   for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)  # # (bn, nc, ks[0], ks[1], L)\n                o = o * weighting\n                # Reverse 1. 
reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization  # norm is shape (1, 1, h, w)\n                return decoded\n            else:\n                if isinstance(self.first_stage_model, VQModelInterface):\n                    return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n                else:\n                    return self.first_stage_model.decode(z)\n\n        else:\n            if isinstance(self.first_stage_model, VQModelInterface):\n                return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)\n            else:\n                return self.first_stage_model.decode(z)\n\n    @torch.no_grad()\n    def encode_first_stage(self, x):\n        if hasattr(self, \"split_input_params\"):\n            if self.split_input_params[\"patch_distributed_vq\"]:\n                ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n                stride = self.split_input_params[\"stride\"]  # eg. 
(64, 64)\n                df = self.split_input_params[\"vqf\"]\n                self.split_input_params['original_image_size'] = x.shape[-2:]\n                bs, nc, h, w = x.shape\n                if ks[0] > h or ks[1] > w:\n                    ks = (min(ks[0], h), min(ks[1], w))\n                    print(\"reducing Kernel\")\n\n                if stride[0] > h or stride[1] > w:\n                    stride = (min(stride[0], h), min(stride[1], w))\n                    print(\"reducing stride\")\n\n                fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)\n                z = unfold(x)  # (bn, nc * prod(**ks), L)\n                # Reshape to img shape\n                z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                output_list = [self.first_stage_model.encode(z[:, :, :, :, i])\n                               for i in range(z.shape[-1])]\n\n                o = torch.stack(output_list, axis=-1)\n                o = o * weighting\n\n                # Reverse reshape to img shape\n                o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n                # stitch crops together\n                decoded = fold(o)\n                decoded = decoded / normalization\n                return decoded\n\n            else:\n                return self.first_stage_model.encode(x)\n        else:\n            return self.first_stage_model.encode(x)\n\n    def shared_step(self, batch, **kwargs):\n        x, c = self.get_input(batch, self.first_stage_key)\n        loss = self(x, c)\n        return loss\n\n    def forward(self, x, c, *args, **kwargs):\n        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()\n        if self.model.conditioning_key is not None:\n            assert c is not None\n            if self.cond_stage_trainable:\n                c = self.get_learned_conditioning(c)\n            if 
self.shorten_cond_schedule:  # TODO: drop this option\n                tc = self.cond_ids[t].to(self.device)\n                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))\n        return self.p_losses(x, c, t, *args, **kwargs)\n\n    def apply_model(self, x_noisy, t, cond, return_ids=False):\n\n        if isinstance(cond, dict):\n            # hybrid case, cond is expected to be a dict\n            pass\n        else:\n            if not isinstance(cond, list):\n                cond = [cond]\n            key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'\n            cond = {key: cond}\n\n        if hasattr(self, \"split_input_params\"):\n            assert len(cond) == 1  # todo can only deal with one conditioning atm\n            assert not return_ids\n            ks = self.split_input_params[\"ks\"]  # eg. (128, 128)\n            stride = self.split_input_params[\"stride\"]  # eg. (64, 64)\n\n            h, w = x_noisy.shape[-2:]\n\n            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)\n\n            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)\n            # Reshape to img shape\n            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]\n\n            if self.cond_stage_key in [\"image\", \"LR_image\", \"segmentation\",\n                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness\n                c_key = next(iter(cond.keys()))  # get key\n                c = next(iter(cond.values()))  # get value\n                assert (len(c) == 1)  # todo extend to list with more than one elem\n                c = c[0]  # get element\n\n                c = unfold(c)\n                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )\n\n                cond_list = [{c_key: [c[:, :, :, 
:, i]]} for i in range(c.shape[-1])]\n\n            elif self.cond_stage_key == 'coordinates_bbox':\n                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'\n\n                # assuming padding of unfold is always 0 and its dilation is always 1\n                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)\n                full_img_h, full_img_w = self.split_input_params['original_image_size']\n                # as we are operating on latents, we need the factor from the original image size to the\n                # spatial latent size to properly rescale the crops for regenerating the bbox annotations\n                num_downs = self.first_stage_model.encoder.num_resolutions - 1\n                rescale_latent = 2 ** (num_downs)\n\n                # get top left positions of patches as conforming for the bbbox tokenizer, therefore we\n                # need to rescale the tl patch coordinates to be in between (0,1)\n                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,\n                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)\n                                        for patch_nr in range(z.shape[-1])]\n\n                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)\n                patch_limits = [(x_tl, y_tl,\n                                 rescale_latent * ks[0] / full_img_w,\n                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]\n                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]\n\n                # tokenize crop coordinates for the bounding boxes of the respective patches\n                patch_limits_tknzd = 
[torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)\n                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)\n                print(patch_limits_tknzd[0].shape)\n                # cut tknzd crop position from conditioning\n                assert isinstance(cond, dict), 'cond must be dict to be fed into model'\n                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)\n                print(cut_cond.shape)\n\n                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])\n                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')\n                print(adapted_cond.shape)\n                adapted_cond = self.get_learned_conditioning(adapted_cond)\n                print(adapted_cond.shape)\n                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])\n                print(adapted_cond.shape)\n\n                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]\n\n            else:\n                cond_list = [cond for i in range(z.shape[-1])]  # Todo make this more efficient\n\n            # apply model by loop over crops\n            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]\n            assert not isinstance(output_list[0],\n                                  tuple)  # todo cant deal with multiple model outputs check this never happens\n\n            o = torch.stack(output_list, axis=-1)\n            o = o * weighting\n            # Reverse reshape to img shape\n            o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)\n            # stitch crops together\n            x_recon = fold(o) / normalization\n\n        else:\n            x_recon = self.model(x_noisy, t, **cond)\n\n        if isinstance(x_recon, tuple) and not return_ids:\n            return x_recon[0]\n        else:\n            return 
x_recon\n\n    def _predict_eps_from_xstart(self, x_t, t, pred_xstart):\n        return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \\\n               extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n\n    def _prior_bpd(self, x_start):\n        \"\"\"\n        Get the prior KL term for the variational lower-bound, measured in\n        bits-per-dim.\n        This term can't be optimized, as it only depends on the encoder.\n        :param x_start: the [N x C x ...] tensor of inputs.\n        :return: a batch of [N] KL values (in bits), one per batch element.\n        \"\"\"\n        batch_size = x_start.shape[0]\n        t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)\n        qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)\n        kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)\n        return mean_flat(kl_prior) / np.log(2.0)\n\n    def p_losses(self, x_start, cond, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)\n        model_output = self.apply_model(x_noisy, t, cond)\n\n        loss_dict = {}\n        prefix = 'train' if self.training else 'val'\n\n        if self.parameterization == \"x0\":\n            target = x_start\n        elif self.parameterization == \"eps\":\n            target = noise\n        else:\n            raise NotImplementedError()\n\n        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])\n        loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})\n\n        logvar_t = self.logvar[t].to(self.device)\n        loss = loss_simple / torch.exp(logvar_t) + logvar_t\n        # loss = loss_simple / torch.exp(self.logvar) + self.logvar\n        if self.learn_logvar:\n            loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})\n            
loss_dict.update({'logvar': self.logvar.data.mean()})\n\n        loss = self.l_simple_weight * loss.mean()\n\n        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))\n        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()\n        loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})\n        loss += (self.original_elbo_weight * loss_vlb)\n        loss_dict.update({f'{prefix}/loss': loss})\n\n        return loss, loss_dict\n\n    def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,\n                        return_x0=False, score_corrector=None, corrector_kwargs=None):\n        t_in = t\n        model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)\n\n        if score_corrector is not None:\n            assert self.parameterization == \"eps\"\n            model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)\n\n        if return_codebook_ids:\n            model_out, logits = model_out\n\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == \"x0\":\n            x_recon = model_out\n        else:\n            raise NotImplementedError()\n\n        if clip_denoised:\n            x_recon.clamp_(-1., 1.)\n        if quantize_denoised:\n            x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)\n        if return_codebook_ids:\n            return model_mean, posterior_variance, posterior_log_variance, logits\n        elif return_x0:\n            return model_mean, posterior_variance, posterior_log_variance, x_recon\n        else:\n            return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, c, t, clip_denoised=False, 
repeat_noise=False,\n                 return_codebook_ids=False, quantize_denoised=False, return_x0=False,\n                 temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n        b, *_, device = *x.shape, x.device\n        outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,\n                                       return_codebook_ids=return_codebook_ids,\n                                       quantize_denoised=quantize_denoised,\n                                       return_x0=return_x0,\n                                       score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n        if return_codebook_ids:\n            raise DeprecationWarning(\"Support dropped.\")\n            model_mean, _, model_log_variance, logits = outputs\n        elif return_x0:\n            model_mean, _, model_log_variance, x0 = outputs\n        else:\n            model_mean, _, model_log_variance = outputs\n\n        noise = noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        # no noise when t == 0\n        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))\n\n        if return_codebook_ids:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)\n        if return_x0:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0\n        else:\n            return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,\n                              img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,\n                              score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,\n        
                      log_every_t=None):\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        timesteps = self.num_timesteps\n        if batch_size is not None:\n            b = batch_size if batch_size is not None else shape[0]\n            shape = [batch_size] + list(shape)\n        else:\n            b = batch_size = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=self.device)\n        else:\n            img = x_T\n        intermediates = []\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                [x[:batch_size] for x in cond[key]] for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',\n                        total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n        if type(temperature) == float:\n            temperature = [temperature] * timesteps\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=self.device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img, x0_partial = self.p_sample(img, cond, ts,\n                                            clip_denoised=self.clip_denoised,\n                                            quantize_denoised=quantize_denoised, return_x0=True,\n                                            temperature=temperature[i], noise_dropout=noise_dropout,\n                                            
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. - mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(x0_partial)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_loop(self, cond, shape, return_intermediates=False,\n                      x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,\n                      mask=None, x0=None, img_callback=None, start_T=None,\n                      log_every_t=None):\n\n        if not log_every_t:\n            log_every_t = self.log_every_t\n        device = self.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        intermediates = [img]\n        if timesteps is None:\n            timesteps = self.num_timesteps\n\n        if start_T is not None:\n            timesteps = min(timesteps, start_T)\n        iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(\n            range(0, timesteps))\n\n        if mask is not None:\n            assert x0 is not None\n            assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match\n\n        for i in iterator:\n            ts = torch.full((b,), i, device=device, dtype=torch.long)\n            if self.shorten_cond_schedule:\n                assert self.model.conditioning_key != 'hybrid'\n                tc = self.cond_ids[ts].to(cond.device)\n                cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))\n\n            img = self.p_sample(img, cond, ts,\n                   
             clip_denoised=self.clip_denoised,\n                                quantize_denoised=quantize_denoised)\n            if mask is not None:\n                img_orig = self.q_sample(x0, ts)\n                img = img_orig * mask + (1. - mask) * img\n\n            if i % log_every_t == 0 or i == timesteps - 1:\n                intermediates.append(img)\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(img, i)\n\n        if return_intermediates:\n            return img, intermediates\n        return img\n\n    @torch.no_grad()\n    def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,\n               verbose=True, timesteps=None, quantize_denoised=False,\n               mask=None, x0=None, shape=None,**kwargs):\n        if shape is None:\n            shape = (batch_size, self.channels, self.image_size, self.image_size)\n        if cond is not None:\n            if isinstance(cond, dict):\n                cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else\n                [x[:batch_size] for x in cond[key]] for key in cond}\n            else:\n                cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]\n        return self.p_sample_loop(cond,\n                                  shape,\n                                  return_intermediates=return_intermediates, x_T=x_T,\n                                  verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,\n                                  mask=mask, x0=x0)\n\n    @torch.no_grad()\n    def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):\n\n        if ddim:\n            ddim_sampler = DDIMSampler(self)\n            shape = (self.channels, self.image_size, self.image_size)\n            samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,\n                                                        
shape,cond,verbose=False,**kwargs)\n\n        else:\n            samples, intermediates = self.sample(cond=cond, batch_size=batch_size,\n                                                 return_intermediates=True,**kwargs)\n\n        return samples, intermediates\n\n\n    @torch.no_grad()\n    def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,\n                   quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False,\n                   plot_diffusion_rows=False, **kwargs):\n\n        use_ddim = False\n\n        log = {}\n        z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,\n                                           return_first_stage_outputs=True,\n                                           force_c_encode=True,\n                                           return_original_cond=True,\n                                           bs=N, uncond=0)\n        N = min(x.shape[0], N)\n        n_row = min(x.shape[0], n_row)\n        log[\"inputs\"] = x\n        log[\"reals\"] = xc[\"c_concat\"]\n        log[\"reconstruction\"] = xrec\n        if self.model.conditioning_key is not None:\n            if hasattr(self.cond_stage_model, \"decode\"):\n                xc = self.cond_stage_model.decode(c)\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key in [\"caption\"]:\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"caption\"])\n                log[\"conditioning\"] = xc\n            elif self.cond_stage_key == 'class_label':\n                xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[\"human_label\"])\n                log['conditioning'] = xc\n            elif isimage(xc):\n                log[\"conditioning\"] = xc\n            if ismap(xc):\n                log[\"original_conditioning\"] = self.to_rgb(xc)\n\n        if plot_diffusion_rows:\n            # get diffusion row\n            diffusion_row = []\n       
     z_start = z[:n_row]\n            for t in range(self.num_timesteps):\n                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:\n                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)\n                    t = t.to(self.device).long()\n                    noise = torch.randn_like(z_start)\n                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)\n                    diffusion_row.append(self.decode_first_stage(z_noisy))\n\n            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W\n            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')\n            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')\n            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])\n            log[\"diffusion_row\"] = diffusion_grid\n\n        if sample:\n            # get denoise row\n            with self.ema_scope(\"Plotting\"):\n                samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                         ddim_steps=ddim_steps,eta=ddim_eta)\n                # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)\n            x_samples = self.decode_first_stage(samples)\n            log[\"samples\"] = x_samples\n            if plot_denoise_rows:\n                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)\n                log[\"denoise_row\"] = denoise_grid\n\n            if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(\n                    self.first_stage_model, IdentityFirstStage):\n                # also display when quantizing x0 while sampling\n                with self.ema_scope(\"Plotting Quantized Denoised\"):\n                    samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,\n                                                   
          ddim_steps=ddim_steps,eta=ddim_eta,\n                                                             quantize_denoised=True)\n                    # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,\n                    #                                      quantize_denoised=True)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_x0_quantized\"] = x_samples\n\n            if inpaint:\n                # make a simple center square\n                h, w = z.shape[2], z.shape[3]\n                mask = torch.ones(N, h, w).to(self.device)\n                # zeros will be filled in\n                mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.\n                mask = mask[:, None, ...]\n                with self.ema_scope(\"Plotting Inpaint\"):\n\n                    samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,\n                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_inpainting\"] = x_samples\n                log[\"mask\"] = mask\n\n                # outpaint\n                with self.ema_scope(\"Plotting Outpaint\"):\n                    samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,\n                                                ddim_steps=ddim_steps, x0=z[:N], mask=mask)\n                x_samples = self.decode_first_stage(samples.to(self.device))\n                log[\"samples_outpainting\"] = x_samples\n\n        if plot_progressive_rows:\n            with self.ema_scope(\"Plotting Progressives\"):\n                img, progressives = self.progressive_denoising(c,\n                                                               shape=(self.channels, self.image_size, self.image_size),\n                                                               
batch_size=N)\n            prog_row = self._get_denoise_row_from_list(progressives, desc=\"Progressive Generation\")\n            log[\"progressive_row\"] = prog_row\n\n        if return_keys:\n            if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:\n                return log\n            else:\n                return {key: log[key] for key in return_keys}\n        return log\n\n    def configure_optimizers(self):\n        lr = self.learning_rate\n        params = list(self.model.parameters())\n        if self.cond_stage_trainable:\n            print(f\"{self.__class__.__name__}: Also optimizing conditioner params!\")\n            params = params + list(self.cond_stage_model.parameters())\n        if self.learn_logvar:\n            print('Diffusion model optimizing logvar')\n            params.append(self.logvar)\n        opt = torch.optim.AdamW(params, lr=lr)\n        if self.use_scheduler:\n            assert 'target' in self.scheduler_config\n            scheduler = instantiate_from_config(self.scheduler_config)\n\n            print(\"Setting up LambdaLR scheduler...\")\n            scheduler = [\n                {\n                    'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),\n                    'interval': 'step',\n                    'frequency': 1\n                }]\n            return [opt], scheduler\n        return opt\n\n    @torch.no_grad()\n    def to_rgb(self, x):\n        x = x.float()\n        if not hasattr(self, \"colorize\"):\n            self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)\n        x = nn.functional.conv2d(x, weight=self.colorize)\n        x = 2. 
* (x - x.min()) / (x.max() - x.min()) - 1.\n        return x\n\n\nclass DiffusionWrapper(pl.LightningModule):\n    def __init__(self, diff_model_config, conditioning_key):\n        super().__init__()\n        self.diffusion_model = instantiate_from_config(diff_model_config)\n        self.conditioning_key = conditioning_key\n        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']\n\n    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):\n        if self.conditioning_key is None:\n            out = self.diffusion_model(x, t)\n        elif self.conditioning_key == 'concat':\n            xc = torch.cat([x] + c_concat, dim=1)\n            out = self.diffusion_model(xc, t)\n        elif self.conditioning_key == 'crossattn':\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(x, t, context=cc)\n        elif self.conditioning_key == 'hybrid':\n            xc = torch.cat([x] + c_concat, dim=1)\n            cc = torch.cat(c_crossattn, 1)\n            out = self.diffusion_model(xc, t, context=cc)\n        elif self.conditioning_key == 'adm':\n            cc = c_crossattn[0]\n            out = self.diffusion_model(x, t, y=cc)\n        else:\n            raise NotImplementedError()\n\n        return out\n\n\nclass Layout2ImgDiffusion(LatentDiffusion):\n    # TODO: move all layout-specific hacks to this class\n    def __init__(self, cond_stage_key, *args, **kwargs):\n        assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key=\"coordinates_bbox\"'\n        super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)\n\n    def log_images(self, batch, N=8, *args, **kwargs):\n        logs = super().log_images(*args, batch=batch, N=N, **kwargs)\n\n        key = 'train' if self.training else 'validation'\n        dset = self.trainer.datamodule.datasets[key]\n        mapper = dset.conditional_builders[self.cond_stage_key]\n\n        bbox_imgs = []\n        
map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))\n        for tknzd_bbox in batch[self.cond_stage_key][:N]:\n            bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))\n            bbox_imgs.append(bboximg)\n\n        cond_img = torch.stack(bbox_imgs, dim=0)\n        logs['bbox_image'] = cond_img\n        return logs\n"
  },
  {
    "path": "modules/models/diffusion/uni_pc/__init__.py",
    "content": "from .sampler import UniPCSampler  # noqa: F401\n"
  },
  {
    "path": "modules/models/diffusion/uni_pc/sampler.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\n\nfrom .uni_pc import NoiseScheduleVP, model_wrapper, UniPC\nfrom modules import shared, devices\n\n\nclass UniPCSampler(object):\n    \"\"\"Adapter that runs the UniPC solver for a diffusion model.\n\n    The wrapped model must expose alphas_cumprod, betas, parameterization and\n    apply_model. A float32 copy of alphas_cumprod is stored on devices.device and\n    used to build a discrete NoiseScheduleVP on every sample() call. Hooks installed\n    through set_hooks() are forwarded unchanged to the UniPC solver.\n    \"\"\"\n\n    def __init__(self, model, **kwargs):\n        super().__init__()\n        self.model = model\n        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)\n        # Every hook defaults to None so sample() works even when set_hooks() was never\n        # called; sample() reads all three attributes unconditionally.\n        self.before_sample = None\n        self.after_sample = None\n        self.after_update = None\n        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))\n\n    def register_buffer(self, name, attr):\n        \"\"\"Set self.<name> = attr, first moving tensors to devices.device.\"\"\"\n        if isinstance(attr, torch.Tensor):\n            if attr.device != devices.device:\n                attr = attr.to(devices.device)\n        setattr(self, name, attr)\n\n    def set_hooks(self, before_sample, after_sample, after_update):\n        \"\"\"Install callbacks that sample() passes through to the UniPC solver.\"\"\"\n        self.before_sample = before_sample\n        self.after_sample = after_sample\n        self.after_update = after_update\n\n    @torch.no_grad()\n    def sample(self,\n               S,\n               batch_size,\n               shape,\n               conditioning=None,\n               callback=None,\n               normals_sequence=None,\n               img_callback=None,\n               quantize_x0=False,\n               eta=0.,\n               mask=None,\n               x0=None,\n               temperature=1.,\n               noise_dropout=0.,\n               score_corrector=None,\n               corrector_kwargs=None,\n               verbose=True,\n               x_T=None,\n               log_every_t=100,\n               unconditional_guidance_scale=1.,\n               unconditional_conditioning=None,\n               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n               **kwargs\n               ):\n        \"\"\"Run UniPC sampling and return (samples, None).\n\n        Only S (number of steps), batch_size, shape (C, H, W), conditioning, x_T,\n        unconditional_guidance_scale and unconditional_conditioning are used; the other\n        parameters are accepted but ignored. Solver options (variant, skip type, order,\n        lower_order_final) are read from shared.opts.\n\n        A batch-size mismatch between conditioning and batch_size only prints a warning.\n        The result is moved to the device of self.model.betas.\n        \"\"\"\n        if conditioning is not None:\n            if isinstance(conditioning, dict):\n                ctmp = conditioning[list(conditioning.keys())[0]]\n                while isinstance(ctmp, list):\n                    ctmp = ctmp[0]\n                cbs = ctmp.shape[0]\n                if cbs != batch_size:\n                    print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n\n            elif isinstance(conditioning, list):\n                for ctmp in conditioning:\n                    # Report this entry's own size; cbs is only bound in the dict branch.\n                    if ctmp.shape[0] != batch_size:\n                        print(f\"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}\")\n\n            else:\n                if conditioning.shape[0] != batch_size:\n                    print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n        # print(f'Data shape for UniPC sampling is {size}')\n\n        device = self.model.betas.device\n        if x_T is None:\n            img = torch.randn(size, device=device)\n        else:\n            img = x_T\n\n        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)\n\n        # SD 1.X is \"noise\", SD 2.X is \"v\"\n        model_type = \"v\" if self.model.parameterization == \"v\" else \"noise\"\n\n        model_fn = model_wrapper(\n            lambda x, t, c: self.model.apply_model(x, t, c),\n            ns,\n            model_type=model_type,\n            guidance_type=\"classifier-free\",\n            #condition=conditioning,\n            #unconditional_condition=unconditional_conditioning,\n            guidance_scale=unconditional_guidance_scale,\n        )\n\n        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)\n        x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method=\"multistep\", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)\n\n        return x.to(device), None\n"
  },
  {
    "path": "modules/models/diffusion/uni_pc/uni_pc.py",
    "content": "import torch\nimport math\nimport tqdm\n\n\nclass NoiseScheduleVP:\n    def __init__(\n            self,\n            schedule='discrete',\n            betas=None,\n            alphas_cumprod=None,\n            continuous_beta_0=0.1,\n            continuous_beta_1=20.,\n        ):\n        \"\"\"Create a wrapper class for the forward SDE (VP type).\n\n        ***\n        Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.\n                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.\n        ***\n\n        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).\n        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).\n        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:\n\n            log_alpha_t = self.marginal_log_mean_coeff(t)\n            sigma_t = self.marginal_std(t)\n            lambda_t = self.marginal_lambda(t)\n\n        Moreover, as lambda(t) is an invertible function, we also support its inverse function:\n\n            t = self.inverse_lambda(lambda_t)\n\n        ===============================================================\n\n        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).\n\n        1. For discrete-time DPMs:\n\n            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:\n                t_i = (i + 1) / N\n            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.\n            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.\n\n            Args:\n                betas: A `torch.Tensor`. 
The beta array for the discrete-time DPM. (See the original DDPM paper for details)\n                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)\n\n            Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.\n\n            **Important**:  Please pay special attention for the args for `alphas_cumprod`:\n                The `alphas_cumprod` is the \\hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that\n                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \\sqrt{\\hat{alpha_n}} * x_0, (1 - \\hat{alpha_n}) * I ).\n                Therefore, the notation \\hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have\n                    alpha_{t_n} = \\sqrt{\\hat{alpha_n}},\n                and\n                    log(alpha_{t_n}) = 0.5 * log(\\hat{alpha_n}).\n\n\n        2. For continuous-time DPMs:\n\n            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise\n            schedule are the default settings in DDPM and improved-DDPM:\n\n            Args:\n                beta_min: A `float` number. The smallest beta for the linear schedule.\n                beta_max: A `float` number. The largest beta for the linear schedule.\n                cosine_s: A `float` number. The hyperparameter in the cosine schedule.\n                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.\n                T: A `float` number. The ending time of the forward process.\n\n        ===============================================================\n\n        Args:\n            schedule: A `str`. The noise schedule of the forward SDE. 
'discrete' for discrete-time DPMs,\n                    'linear' or 'cosine' for continuous-time DPMs.\n        Returns:\n            A wrapper object of the forward SDE (VP type).\n\n        ===============================================================\n\n        Example:\n\n        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):\n        >>> ns = NoiseScheduleVP('discrete', betas=betas)\n\n        # For discrete-time DPMs, given alphas_cumprod (the \\hat{alpha_n} array for n = 0, 1, ..., N - 1):\n        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)\n\n        # For continuous-time DPMs (VPSDE), linear schedule:\n        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)\n\n        \"\"\"\n\n        if schedule not in ['discrete', 'linear', 'cosine']:\n            raise ValueError(f\"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'\")\n\n        self.schedule = schedule\n        if schedule == 'discrete':\n            if betas is not None:\n                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)\n            else:\n                assert alphas_cumprod is not None\n                log_alphas = 0.5 * torch.log(alphas_cumprod)\n            self.total_N = len(log_alphas)\n            self.T = 1.\n            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))\n            self.log_alpha_array = log_alphas.reshape((1, -1,))\n        else:\n            self.total_N = 1000\n            self.beta_0 = continuous_beta_0\n            self.beta_1 = continuous_beta_1\n            self.cosine_s = 0.008\n            self.cosine_beta_max = 999.\n            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s\n            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. 
+ self.cosine_s) * math.pi / 2.))\n            self.schedule = schedule\n            if schedule == 'cosine':\n                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.\n                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.\n                self.T = 0.9946\n            else:\n                self.T = 1.\n\n    def marginal_log_mean_coeff(self, t):\n        \"\"\"\n        Compute log(alpha_t) of a given continuous-time label t in [0, T].\n        \"\"\"\n        if self.schedule == 'discrete':\n            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))\n        elif self.schedule == 'linear':\n            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0\n        elif self.schedule == 'cosine':\n            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))\n            log_alpha_t =  log_alpha_fn(t) - self.cosine_log_alpha_0\n            return log_alpha_t\n\n    def marginal_alpha(self, t):\n        \"\"\"\n        Compute alpha_t of a given continuous-time label t in [0, T].\n        \"\"\"\n        return torch.exp(self.marginal_log_mean_coeff(t))\n\n    def marginal_std(self, t):\n        \"\"\"\n        Compute sigma_t of a given continuous-time label t in [0, T].\n        \"\"\"\n        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))\n\n    def marginal_lambda(self, t):\n        \"\"\"\n        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].\n        \"\"\"\n        log_mean_coeff = self.marginal_log_mean_coeff(t)\n        log_std = 0.5 * torch.log(1. - torch.exp(2. 
* log_mean_coeff))\n        return log_mean_coeff - log_std\n\n    def inverse_lambda(self, lamb):\n        \"\"\"\n        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.\n        \"\"\"\n        if self.schedule == 'linear':\n            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))\n            Delta = self.beta_0**2 + tmp\n            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)\n        elif self.schedule == 'discrete':\n            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)\n            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))\n            return t.reshape((-1,))\n        else:\n            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))\n            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s\n            t = t_fn(log_alpha)\n            return t\n\n\ndef model_wrapper(\n    model,\n    noise_schedule,\n    model_type=\"noise\",\n    model_kwargs=None,\n    guidance_type=\"uncond\",\n    #condition=None,\n    #unconditional_condition=None,\n    guidance_scale=1.,\n    classifier_fn=None,\n    classifier_kwargs=None,\n):\n    \"\"\"Create a wrapper function for the noise prediction model.\n\n    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to\n    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.\n\n    We support four types of the diffusion model by setting `model_type`:\n\n        1. \"noise\": noise prediction model. (Trained by predicting noise).\n\n        2. \"x_start\": data prediction model. 
(Trained by predicting the data x_0 at time 0).\n\n        3. \"v\": velocity prediction model. (Trained by predicting the velocity).\n            The \"v\" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].\n\n            [1] Salimans, Tim, and Jonathan Ho. \"Progressive distillation for fast sampling of diffusion models.\"\n                arXiv preprint arXiv:2202.00512 (2022).\n            [2] Ho, Jonathan, et al. \"Imagen Video: High Definition Video Generation with Diffusion Models.\"\n                arXiv preprint arXiv:2210.02303 (2022).\n\n        4. \"score\": marginal score function. (Trained by denoising score matching).\n            Note that the score function and the noise prediction model follows a simple relationship:\n            ```\n                noise(x_t, t) = -sigma_t * score(x_t, t)\n            ```\n\n    We support three types of guided sampling by DPMs by setting `guidance_type`:\n        1. \"uncond\": unconditional sampling by DPMs.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, **model_kwargs) -> noise | x_start | v | score\n            ``\n\n        2. \"classifier\": classifier guidance sampling [3] by DPMs and another classifier.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, **model_kwargs) -> noise | x_start | v | score\n            ``\n\n            The input `classifier_fn` has the following format:\n            ``\n                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)\n            ``\n\n            [3] P. Dhariwal and A. Q. Nichol, \"Diffusion models beat GANs on image synthesis,\"\n                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.\n\n        3. 
\"classifier-free\": classifier-free guidance sampling by conditional DPMs.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score\n            ``\n            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.\n\n            [4] Ho, Jonathan, and Tim Salimans. \"Classifier-free diffusion guidance.\"\n                arXiv preprint arXiv:2207.12598 (2022).\n\n\n    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)\n    or continuous-time labels (i.e. epsilon to T).\n\n    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:\n    ``\n        def model_fn(x, t_continuous) -> noise:\n            t_input = get_model_input_time(t_continuous)\n            return noise_pred(model, x, t_input, **model_kwargs)\n    ``\n    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.\n\n    ===============================================================\n\n    Args:\n        model: A diffusion model with the corresponding format described above.\n        noise_schedule: A noise schedule object, such as NoiseScheduleVP.\n        model_type: A `str`. The parameterization type of the diffusion model.\n                    \"noise\" or \"x_start\" or \"v\" or \"score\".\n        model_kwargs: A `dict`. A dict for the other inputs of the model function.\n        guidance_type: A `str`. The type of the guidance for sampling.\n                    \"uncond\" or \"classifier\" or \"classifier-free\".\n        condition: A pytorch tensor. The condition for the guided sampling.\n                    Only used for \"classifier\" or \"classifier-free\" guidance type.\n        unconditional_condition: A pytorch tensor. 
The condition for the unconditional sampling.\n                    Only used for \"classifier-free\" guidance type.\n        guidance_scale: A `float`. The scale for the guided sampling.\n        classifier_fn: A classifier function. Only used for the classifier guidance.\n        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.\n    Returns:\n        A noise prediction model that accepts the noised data and the continuous time as the inputs.\n    \"\"\"\n\n    model_kwargs = model_kwargs or {}\n    classifier_kwargs = classifier_kwargs or {}\n\n    def get_model_input_time(t_continuous):\n        \"\"\"\n        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.\n        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].\n        For continuous-time DPMs, we just use `t_continuous`.\n        \"\"\"\n        if noise_schedule.schedule == 'discrete':\n            return (t_continuous - 1. 
/ noise_schedule.total_N) * 1000.\n        else:\n            return t_continuous\n\n    def noise_pred_fn(x, t_continuous, cond=None):\n        if t_continuous.reshape((-1,)).shape[0] == 1:\n            t_continuous = t_continuous.expand((x.shape[0]))\n        t_input = get_model_input_time(t_continuous)\n        if cond is None:\n            output = model(x, t_input, None, **model_kwargs)\n        else:\n            output = model(x, t_input, cond, **model_kwargs)\n        if model_type == \"noise\":\n            return output\n        elif model_type == \"x_start\":\n            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)\n        elif model_type == \"v\":\n            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x\n        elif model_type == \"score\":\n            sigma_t = noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return -expand_dims(sigma_t, dims) * output\n\n    def cond_grad_fn(x, t_input, condition):\n        \"\"\"\n        Compute the gradient of the classifier, i.e. 
nabla_{x} log p_t(cond | x_t).\n        \"\"\"\n        with torch.enable_grad():\n            x_in = x.detach().requires_grad_(True)\n            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)\n            return torch.autograd.grad(log_prob.sum(), x_in)[0]\n\n    def model_fn(x, t_continuous, condition, unconditional_condition):\n        \"\"\"\n        The noise prediction model function that is used for DPM-Solver.\n        \"\"\"\n        if t_continuous.reshape((-1,)).shape[0] == 1:\n            t_continuous = t_continuous.expand((x.shape[0]))\n        if guidance_type == \"uncond\":\n            return noise_pred_fn(x, t_continuous)\n        elif guidance_type == \"classifier\":\n            assert classifier_fn is not None\n            t_input = get_model_input_time(t_continuous)\n            cond_grad = cond_grad_fn(x, t_input, condition)\n            sigma_t = noise_schedule.marginal_std(t_continuous)\n            noise = noise_pred_fn(x, t_continuous)\n            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad\n        elif guidance_type == \"classifier-free\":\n            if guidance_scale == 1. 
or unconditional_condition is None:\n                return noise_pred_fn(x, t_continuous, cond=condition)\n            else:\n                x_in = torch.cat([x] * 2)\n                t_in = torch.cat([t_continuous] * 2)\n                if isinstance(condition, dict):\n                    assert isinstance(unconditional_condition, dict)\n                    c_in = {}\n                    for k in condition:\n                        if isinstance(condition[k], list):\n                            c_in[k] = [torch.cat([\n                                unconditional_condition[k][i],\n                                condition[k][i]]) for i in range(len(condition[k]))]\n                        else:\n                            c_in[k] = torch.cat([\n                                unconditional_condition[k],\n                                condition[k]])\n                elif isinstance(condition, list):\n                    c_in = []\n                    assert isinstance(unconditional_condition, list)\n                    for i in range(len(condition)):\n                        c_in.append(torch.cat([unconditional_condition[i], condition[i]]))\n                else:\n                    c_in = torch.cat([unconditional_condition, condition])\n                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)\n                return noise_uncond + guidance_scale * (noise - noise_uncond)\n\n    assert model_type in [\"noise\", \"x_start\", \"v\"]\n    assert guidance_type in [\"uncond\", \"classifier\", \"classifier-free\"]\n    return model_fn\n\n\nclass UniPC:\n    def __init__(\n        self,\n        model_fn,\n        noise_schedule,\n        predict_x0=True,\n        thresholding=False,\n        max_val=1.,\n        variant='bh1',\n        condition=None,\n        unconditional_condition=None,\n        before_sample=None,\n        after_sample=None,\n        after_update=None\n    ):\n        \"\"\"Construct a UniPC.\n\n        We support both 
data_prediction and noise_prediction.\n        \"\"\"\n        self.model_fn_ = model_fn\n        self.noise_schedule = noise_schedule\n        self.variant = variant\n        self.predict_x0 = predict_x0\n        self.thresholding = thresholding\n        self.max_val = max_val\n        self.condition = condition\n        self.unconditional_condition = unconditional_condition\n        self.before_sample = before_sample\n        self.after_sample = after_sample\n        self.after_update = after_update\n\n    def dynamic_thresholding_fn(self, x0, t=None):\n        \"\"\"\n        The dynamic thresholding method.\n        \"\"\"\n        dims = x0.dim()\n        p = self.dynamic_thresholding_ratio\n        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)\n        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)\n        x0 = torch.clamp(x0, -s, s) / s\n        return x0\n\n    def model(self, x, t):\n        cond = self.condition\n        uncond = self.unconditional_condition\n        if self.before_sample is not None:\n            x, t, cond, uncond = self.before_sample(x, t, cond, uncond)\n        res = self.model_fn_(x, t, cond, uncond)\n        if self.after_sample is not None:\n            x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res)\n\n        if isinstance(res, tuple):\n            # (None, pred_x0)\n            res = res[1]\n\n        return res\n\n    def noise_prediction_fn(self, x, t):\n        \"\"\"\n        Return the noise prediction model.\n        \"\"\"\n        return self.model(x, t)\n\n    def data_prediction_fn(self, x, t):\n        \"\"\"\n        Return the data prediction model (with thresholding).\n        \"\"\"\n        noise = self.noise_prediction_fn(x, t)\n        dims = x.dim()\n        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)\n        x0 = (x - expand_dims(sigma_t, dims) * noise) / 
expand_dims(alpha_t, dims)\n        if self.thresholding:\n            p = 0.995   # A hyperparameter in the paper of \"Imagen\" [1].\n            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)\n            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)\n            x0 = torch.clamp(x0, -s, s) / s\n        return x0\n\n    def model_fn(self, x, t):\n        \"\"\"\n        Convert the model to the noise prediction model or the data prediction model.\n        \"\"\"\n        if self.predict_x0:\n            return self.data_prediction_fn(x, t)\n        else:\n            return self.noise_prediction_fn(x, t)\n\n    def get_time_steps(self, skip_type, t_T, t_0, N, device):\n        \"\"\"Compute the intermediate time steps for sampling.\n        \"\"\"\n        if skip_type == 'logSNR':\n            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))\n            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))\n            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)\n            return self.noise_schedule.inverse_lambda(logSNR_steps)\n        elif skip_type == 'time_uniform':\n            return torch.linspace(t_T, t_0, N + 1).to(device)\n        elif skip_type == 'time_quadratic':\n            t_order = 2\n            t = torch.linspace(t_T**(1. / t_order), t_0**(1. 
/ t_order), N + 1).pow(t_order).to(device)\n            return t\n        else:\n            raise ValueError(f\"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'\")\n\n    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):\n        \"\"\"\n        Get the order of each step for sampling by the singlestep DPM-Solver.\n        \"\"\"\n        if order == 3:\n            K = steps // 3 + 1\n            if steps % 3 == 0:\n                orders = [3,] * (K - 2) + [2, 1]\n            elif steps % 3 == 1:\n                orders = [3,] * (K - 1) + [1]\n            else:\n                orders = [3,] * (K - 1) + [2]\n        elif order == 2:\n            if steps % 2 == 0:\n                K = steps // 2\n                orders = [2,] * K\n            else:\n                K = steps // 2 + 1\n                orders = [2,] * (K - 1) + [1]\n        elif order == 1:\n            K = steps\n            orders = [1,] * steps\n        else:\n            raise ValueError(\"'order' must be '1' or '2' or '3'.\")\n        if skip_type == 'logSNR':\n            # To reproduce the results in DPM-Solver paper\n            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)\n        else:\n            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]\n        return timesteps_outer, orders\n\n    def denoise_to_zero_fn(self, x, s):\n        \"\"\"\n        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.\n        \"\"\"\n        return self.data_prediction_fn(x, s)\n\n    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):\n        if len(t.shape) == 0:\n            t = t.view(-1)\n        if 'bh' in self.variant:\n            return self.multistep_uni_pc_bh_update(x, model_prev_list, 
t_prev_list, t, order, **kwargs)\n        else:\n            assert self.variant == 'vary_coeff'\n            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)\n\n    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):\n        #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')\n        ns = self.noise_schedule\n        assert order <= len(model_prev_list)\n\n        # first compute rks\n        t_prev_0 = t_prev_list[-1]\n        lambda_prev_0 = ns.marginal_lambda(t_prev_0)\n        lambda_t = ns.marginal_lambda(t)\n        model_prev_0 = model_prev_list[-1]\n        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)\n        log_alpha_t = ns.marginal_log_mean_coeff(t)\n        alpha_t = torch.exp(log_alpha_t)\n\n        h = lambda_t - lambda_prev_0\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            t_prev_i = t_prev_list[-(i + 1)]\n            model_prev_i = model_prev_list[-(i + 1)]\n            lambda_prev_i = ns.marginal_lambda(t_prev_i)\n            rk = (lambda_prev_i - lambda_prev_0) / h\n            rks.append(rk)\n            D1s.append((model_prev_i - model_prev_0) / rk)\n\n        rks.append(1.)\n        rks = torch.tensor(rks, device=x.device)\n\n        K = len(rks)\n        # build C matrix\n        C = []\n\n        col = torch.ones_like(rks)\n        for k in range(1, K + 1):\n            C.append(col)\n            col = col * rks / (k + 1)\n        C = torch.stack(C, dim=1)\n\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1) # (B, K)\n            C_inv_p = torch.linalg.inv(C[:-1, :-1])\n            A_p = C_inv_p\n\n        if use_corrector:\n            #print('using corrector')\n            C_inv = torch.linalg.inv(C)\n            A_c = C_inv\n\n        hh = -h if self.predict_x0 else h\n        h_phi_1 = torch.expm1(hh)\n        h_phi_ks = 
[]\n        factorial_k = 1\n        h_phi_k = h_phi_1\n        for k in range(1, K + 2):\n            h_phi_ks.append(h_phi_k)\n            h_phi_k = h_phi_k / hh - 1 / factorial_k\n            factorial_k *= (k + 1)\n\n        model_t = None\n        if self.predict_x0:\n            x_t_ = (\n                sigma_t / sigma_prev_0 * x\n                - alpha_t * h_phi_1 * model_prev_0\n            )\n            # now predictor\n            x_t = x_t_\n            if len(D1s) > 0:\n                # compute the residuals for predictor\n                for k in range(K - 1):\n                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])\n            # now corrector\n            if use_corrector:\n                model_t = self.model_fn(x_t, t)\n                D1_t = (model_t - model_prev_0)\n                x_t = x_t_\n                k = 0\n                for k in range(K - 1):\n                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])\n                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])\n        else:\n            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)\n            x_t_ = (\n                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x\n                - (sigma_t * h_phi_1) * model_prev_0\n            )\n            # now predictor\n            x_t = x_t_\n            if len(D1s) > 0:\n                # compute the residuals for predictor\n                for k in range(K - 1):\n                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])\n            # now corrector\n            if use_corrector:\n                model_t = self.model_fn(x_t, t)\n                D1_t = (model_t - model_prev_0)\n                x_t = x_t_\n                k = 0\n                for k in range(K - 1):\n                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * 
torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])\n                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])\n        return x_t, model_t\n\n    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):\n        #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')\n        ns = self.noise_schedule\n        assert order <= len(model_prev_list)\n        dims = x.dim()\n\n        # first compute rks\n        t_prev_0 = t_prev_list[-1]\n        lambda_prev_0 = ns.marginal_lambda(t_prev_0)\n        lambda_t = ns.marginal_lambda(t)\n        model_prev_0 = model_prev_list[-1]\n        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)\n        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)\n        alpha_t = torch.exp(log_alpha_t)\n\n        h = lambda_t - lambda_prev_0\n\n        rks = []\n        D1s = []\n        for i in range(1, order):\n            t_prev_i = t_prev_list[-(i + 1)]\n            model_prev_i = model_prev_list[-(i + 1)]\n            lambda_prev_i = ns.marginal_lambda(t_prev_i)\n            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]\n            rks.append(rk)\n            D1s.append((model_prev_i - model_prev_0) / rk)\n\n        rks.append(1.)\n        rks = torch.tensor(rks, device=x.device)\n\n        R = []\n        b = []\n\n        hh = -h[0] if self.predict_x0 else h[0]\n        h_phi_1 = torch.expm1(hh) # h\\phi_1(h) = e^h - 1\n        h_phi_k = h_phi_1 / hh - 1\n\n        factorial_i = 1\n\n        if self.variant == 'bh1':\n            B_h = hh\n        elif self.variant == 'bh2':\n            B_h = torch.expm1(hh)\n        else:\n            raise NotImplementedError()\n\n        for i in range(1, order + 1):\n            R.append(torch.pow(rks, i - 1))\n            b.append(h_phi_k * factorial_i / B_h)\n            factorial_i *= (i + 1)\n            h_phi_k = h_phi_k / hh 
- 1 / factorial_i\n\n        R = torch.stack(R)\n        b = torch.tensor(b, device=x.device)\n\n        # now predictor\n        use_predictor = len(D1s) > 0 and x_t is None\n        if len(D1s) > 0:\n            D1s = torch.stack(D1s, dim=1) # (B, K)\n            if x_t is None:\n                # for order 2, we use a simplified version\n                if order == 2:\n                    rhos_p = torch.tensor([0.5], device=b.device)\n                else:\n                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])\n        else:\n            D1s = None\n\n        if use_corrector:\n            #print('using corrector')\n            # for order 1, we use a simplified version\n            if order == 1:\n                rhos_c = torch.tensor([0.5], device=b.device)\n            else:\n                rhos_c = torch.linalg.solve(R, b)\n\n        model_t = None\n        if self.predict_x0:\n            x_t_ = (\n                expand_dims(sigma_t / sigma_prev_0, dims) * x\n                - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0\n            )\n\n            if x_t is None:\n                if use_predictor:\n                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)\n                else:\n                    pred_res = 0\n                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res\n\n            if use_corrector:\n                model_t = self.model_fn(x_t, t)\n                if D1s is not None:\n                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)\n                else:\n                    corr_res = 0\n                D1_t = (model_t - model_prev_0)\n                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)\n        else:\n            x_t_ = (\n                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x\n                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0\n            )\n            if x_t is None:\n         
       if use_predictor:\n                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)\n                else:\n                    pred_res = 0\n                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res\n\n            if use_corrector:\n                model_t = self.model_fn(x_t, t)\n                if D1s is not None:\n                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)\n                else:\n                    corr_res = 0\n                D1_t = (model_t - model_prev_0)\n                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)\n        return x_t, model_t\n\n\n    def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',\n        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',\n        atol=0.0078, rtol=0.05, corrector=False,\n    ):\n        t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end\n        t_T = self.noise_schedule.T if t_start is None else t_start\n        device = x.device\n        if method == 'multistep':\n            assert steps >= order, \"UniPC order must be < sampling steps\"\n            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)\n            #print(f\"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}\")\n            assert timesteps.shape[0] - 1 == steps\n            with torch.no_grad():\n                vec_t = timesteps[0].expand((x.shape[0]))\n                model_prev_list = [self.model_fn(x, vec_t)]\n                t_prev_list = [vec_t]\n                with tqdm.tqdm(total=steps) as pbar:\n                    # Init the first `order` values by lower order multistep DPM-Solver.\n                    for init_order in range(1, order):\n                        vec_t = timesteps[init_order].expand(x.shape[0])\n                        x, model_x = 
self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)\n                        if model_x is None:\n                            model_x = self.model_fn(x, vec_t)\n                        if self.after_update is not None:\n                            self.after_update(x, model_x)\n                        model_prev_list.append(model_x)\n                        t_prev_list.append(vec_t)\n                        pbar.update()\n\n                    for step in range(order, steps + 1):\n                        vec_t = timesteps[step].expand(x.shape[0])\n                        if lower_order_final:\n                            step_order = min(order, steps + 1 - step)\n                        else:\n                            step_order = order\n                        #print('this step order:', step_order)\n                        if step == steps:\n                            #print('do not run corrector at the last step')\n                            use_corrector = False\n                        else:\n                            use_corrector = True\n                        x, model_x =  self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)\n                        if self.after_update is not None:\n                            self.after_update(x, model_x)\n                        for i in range(order - 1):\n                            t_prev_list[i] = t_prev_list[i + 1]\n                            model_prev_list[i] = model_prev_list[i + 1]\n                        t_prev_list[-1] = vec_t\n                        # We do not need to evaluate the final model value.\n                        if step < steps:\n                            if model_x is None:\n                                model_x = self.model_fn(x, vec_t)\n                            model_prev_list[-1] = model_x\n                        pbar.update()\n        else:\n            raise 
NotImplementedError()\n        if denoise_to_zero:\n            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)\n        return x\n\n\n#############################################################\n# other utility functions\n#############################################################\n\ndef interpolate_fn(x, xp, yp):\n    \"\"\"\n    A piecewise linear function y = f(x), using xp and yp as keypoints.\n    We implement f(x) in a differentiable way (i.e. applicable for autograd).\n    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)\n\n    Args:\n        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).\n        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.\n        yp: PyTorch tensor with shape [C, K].\n    Returns:\n        The function values f(x), with shape [N, C].\n    \"\"\"\n    N, K = x.shape[0], xp.shape[1]\n    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)\n    sorted_all_x, x_indices = torch.sort(all_x, dim=2)\n    x_idx = torch.argmin(x_indices, dim=2)\n    cand_start_idx = x_idx - 1\n    start_idx = torch.where(\n        torch.eq(x_idx, 0),\n        torch.tensor(1, device=x.device),\n        torch.where(\n            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,\n        ),\n    )\n    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)\n    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)\n    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)\n    start_idx2 = torch.where(\n        torch.eq(x_idx, 0),\n        torch.tensor(0, device=x.device),\n        torch.where(\n            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,\n        ),\n    
)\n    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)\n    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)\n    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)\n    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)\n    return cand\n\n\ndef expand_dims(v, dims):\n    \"\"\"\n    Expand the tensor `v` to the dim `dims`.\n\n    Args:\n        `v`: a PyTorch tensor with shape [N].\n        `dim`: a `int`.\n    Returns:\n        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.\n    \"\"\"\n    return v[(...,) + (None,)*(dims - 1)]\n"
  },
  {
    "path": "modules/models/sd3/mmdit.py",
    "content": "### This file contains impls for MM-DiT, the core model component of SD3\n\nimport math\nfrom typing import Dict, Optional\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom einops import rearrange, repeat\nfrom modules.models.sd3.other_impls import attention, Mlp\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\" 2D Image to Patch Embedding\"\"\"\n    def __init__(\n            self,\n            img_size: Optional[int] = 224,\n            patch_size: int = 16,\n            in_chans: int = 3,\n            embed_dim: int = 768,\n            flatten: bool = True,\n            bias: bool = True,\n            strict_img_size: bool = True,\n            dynamic_img_pad: bool = False,\n            dtype=None,\n            device=None,\n    ):\n        super().__init__()\n        self.patch_size = (patch_size, patch_size)\n        if img_size is not None:\n            self.img_size = (img_size, img_size)\n            self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])\n            self.num_patches = self.grid_size[0] * self.grid_size[1]\n        else:\n            self.img_size = None\n            self.grid_size = None\n            self.num_patches = None\n\n        # flatten spatial dim and transpose to channels last, kept for bwd compat\n        self.flatten = flatten\n        self.strict_img_size = strict_img_size\n        self.dynamic_img_pad = dynamic_img_pad\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        x = self.proj(x)\n        if self.flatten:\n            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC\n        return x\n\n\ndef modulate(x, shift, scale):\n    if shift is None:\n        shift = torch.zeros_like(scale)\n    return x * (1 + scale.unsqueeze(1)) + 
shift.unsqueeze(1)\n\n\n#################################################################################\n#                   Sine/Cosine Positional Embedding Functions                  #\n#################################################################################\n\n\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, scaling_factor=None, offset=None):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n    if scaling_factor is not None:\n        grid = grid / scaling_factor\n    if offset is not None:\n        grid = grid - offset\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token and extra_tokens > 0:\n        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 2 == 0\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    omega = np.arange(embed_dim // 2, dtype=np.float64)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n    pos = pos.reshape(-1)  # (M,)\n    
out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n\n\n#################################################################################\n#               Embedding Layers for Timesteps and Class Labels                 #\n#################################################################################\n\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"Embeds scalar timesteps into vector representations.\"\"\"\n\n    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n            * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        if torch.is_floating_point(t):\n            embedding = 
embedding.to(dtype=t.dtype)\n        return embedding\n\n    def forward(self, t, dtype, **kwargs):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\n\nclass VectorEmbedder(nn.Module):\n    \"\"\"Embeds a flat vector of dimension input_dim\"\"\"\n\n    def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.mlp(x)\n\n\n#################################################################################\n#                                 Core DiT Model                                #\n#################################################################################\n\n\nclass QkvLinear(torch.nn.Linear):\n    pass\n\ndef split_qkv(qkv, head_dim):\n    qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0)\n    return qkv[0], qkv[1], qkv[2]\n\ndef optimized_attention(qkv, num_heads):\n    return attention(qkv[0], qkv[1], qkv[2], num_heads)\n\nclass SelfAttention(nn.Module):\n    ATTENTION_MODES = (\"xformers\", \"torch\", \"torch-hb\", \"math\", \"debug\")\n\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        qk_scale: Optional[float] = None,\n        attn_mode: str = \"xformers\",\n        pre_only: bool = False,\n        qk_norm: Optional[str] = None,\n        rmsnorm: bool = False,\n        dtype=None,\n        device=None,\n    ):\n        super().__init__()\n        self.num_heads = num_heads\n        self.head_dim = dim // num_heads\n\n        self.qkv = QkvLinear(dim, dim * 3, bias=qkv_bias, dtype=dtype, 
device=device)\n        if not pre_only:\n            self.proj = nn.Linear(dim, dim, dtype=dtype, device=device)\n        assert attn_mode in self.ATTENTION_MODES\n        self.attn_mode = attn_mode\n        self.pre_only = pre_only\n\n        if qk_norm == \"rms\":\n            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)\n            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)\n        elif qk_norm == \"ln\":\n            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)\n            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device)\n        elif qk_norm is None:\n            self.ln_q = nn.Identity()\n            self.ln_k = nn.Identity()\n        else:\n            raise ValueError(qk_norm)\n\n    def pre_attention(self, x: torch.Tensor):\n        B, L, C = x.shape\n        qkv = self.qkv(x)\n        q, k, v = split_qkv(qkv, self.head_dim)\n        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)\n        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)\n        return (q, k, v)\n\n    def post_attention(self, x: torch.Tensor) -> torch.Tensor:\n        assert not self.pre_only\n        x = self.proj(x)\n        return x\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        (q, k, v) = self.pre_attention(x)\n        x = attention(q, k, v, self.num_heads)\n        x = self.post_attention(x)\n        return x\n\n\nclass RMSNorm(torch.nn.Module):\n    def __init__(\n        self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None\n    ):\n        \"\"\"\n        Initialize the RMSNorm normalization layer.\n        Args:\n            dim (int): The dimension of the input tensor.\n            eps (float, optional): A small value added to the denominator for numerical stability. 
Default is 1e-6.\n        Attributes:\n            eps (float): A small value added to the denominator for numerical stability.\n            weight (nn.Parameter): Learnable scaling parameter.\n        \"\"\"\n        super().__init__()\n        self.eps = eps\n        self.learnable_scale = elementwise_affine\n        if self.learnable_scale:\n            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))\n        else:\n            self.register_parameter(\"weight\", None)\n\n    def _norm(self, x):\n        \"\"\"\n        Apply the RMSNorm normalization to the input tensor.\n        Args:\n            x (torch.Tensor): The input tensor.\n        Returns:\n            torch.Tensor: The normalized tensor.\n        \"\"\"\n        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n\n    def forward(self, x):\n        \"\"\"\n        Forward pass through the RMSNorm layer.\n        Args:\n            x (torch.Tensor): The input tensor.\n        Returns:\n            torch.Tensor: The output tensor after applying RMSNorm.\n        \"\"\"\n        x = self._norm(x)\n        if self.learnable_scale:\n            return x * self.weight.to(device=x.device, dtype=x.dtype)\n        else:\n            return x\n\n\nclass SwiGLUFeedForward(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        hidden_dim: int,\n        multiple_of: int,\n        ffn_dim_multiplier: Optional[float] = None,\n    ):\n        \"\"\"\n        Initialize the FeedForward module.\n\n        Args:\n            dim (int): Input dimension.\n            hidden_dim (int): Hidden dimension of the feedforward layer.\n            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.\n            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. 
Defaults to None.\n\n        Attributes:\n            w1 (ColumnParallelLinear): Linear transformation for the first layer.\n            w2 (RowParallelLinear): Linear transformation for the second layer.\n            w3 (ColumnParallelLinear): Linear transformation for the third layer.\n\n        \"\"\"\n        super().__init__()\n        hidden_dim = int(2 * hidden_dim / 3)\n        # custom dim factor multiplier\n        if ffn_dim_multiplier is not None:\n            hidden_dim = int(ffn_dim_multiplier * hidden_dim)\n        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)\n\n        self.w1 = nn.Linear(dim, hidden_dim, bias=False)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False)\n        self.w3 = nn.Linear(dim, hidden_dim, bias=False)\n\n    def forward(self, x):\n        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))\n\n\nclass DismantledBlock(nn.Module):\n    \"\"\"A DiT block with gated adaptive layer norm (adaLN) conditioning.\"\"\"\n\n    ATTENTION_MODES = (\"xformers\", \"torch\", \"torch-hb\", \"math\", \"debug\")\n\n    def __init__(\n        self,\n        hidden_size: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        attn_mode: str = \"xformers\",\n        qkv_bias: bool = False,\n        pre_only: bool = False,\n        rmsnorm: bool = False,\n        scale_mod_only: bool = False,\n        swiglu: bool = False,\n        qk_norm: Optional[str] = None,\n        dtype=None,\n        device=None,\n        **block_kwargs,\n    ):\n        super().__init__()\n        assert attn_mode in self.ATTENTION_MODES\n        if not rmsnorm:\n            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)\n        else:\n            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        self.attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=pre_only, 
qk_norm=qk_norm, rmsnorm=rmsnorm, dtype=dtype, device=device)\n        if not pre_only:\n            if not rmsnorm:\n                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)\n            else:\n                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)\n        mlp_hidden_dim = int(hidden_size * mlp_ratio)\n        if not pre_only:\n            if not swiglu:\n                self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=nn.GELU(approximate=\"tanh\"), dtype=dtype, device=device)\n            else:\n                self.mlp = SwiGLUFeedForward(dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256)\n        self.scale_mod_only = scale_mod_only\n        if not scale_mod_only:\n            n_mods = 6 if not pre_only else 2\n        else:\n            n_mods = 4 if not pre_only else 1\n        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device))\n        self.pre_only = pre_only\n\n    def pre_attention(self, x: torch.Tensor, c: torch.Tensor):\n        assert x is not None, \"pre_attention called with None input\"\n        if not self.pre_only:\n            if not self.scale_mod_only:\n                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)\n            else:\n                shift_msa = None\n                shift_mlp = None\n                scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(4, dim=1)\n            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))\n            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)\n        else:\n            if not self.scale_mod_only:\n                shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1)\n            else:\n                shift_msa = None\n                scale_msa 
= self.adaLN_modulation(c)\n            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))\n            return qkv, None\n\n    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):\n        assert not self.pre_only\n        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)\n        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))\n        return x\n\n    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n        assert not self.pre_only\n        (q, k, v), intermediates = self.pre_attention(x, c)\n        attn = attention(q, k, v, self.attn.num_heads)\n        return self.post_attention(attn, *intermediates)\n\n\ndef block_mixing(context, x, context_block, x_block, c):\n    assert context is not None, \"block_mixing called with None context\"\n    context_qkv, context_intermediates = context_block.pre_attention(context, c)\n\n    x_qkv, x_intermediates = x_block.pre_attention(x, c)\n\n    o = []\n    for t in range(3):\n        o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1))\n    q, k, v = tuple(o)\n\n    attn = attention(q, k, v, x_block.attn.num_heads)\n    context_attn, x_attn = (attn[:, : context_qkv[0].shape[1]], attn[:, context_qkv[0].shape[1] :])\n\n    if not context_block.pre_only:\n        context = context_block.post_attention(context_attn, *context_intermediates)\n    else:\n        context = None\n    x = x_block.post_attention(x_attn, *x_intermediates)\n    return context, x\n\n\nclass JointBlock(nn.Module):\n    \"\"\"just a small wrapper to serve as a fsdp unit\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__()\n        pre_only = kwargs.pop(\"pre_only\")\n        qk_norm = kwargs.pop(\"qk_norm\", None)\n        self.context_block = DismantledBlock(*args, pre_only=pre_only, qk_norm=qk_norm, **kwargs)\n        self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs)\n\n    
def forward(self, *args, **kwargs):\n        return block_mixing(*args, context_block=self.context_block, x_block=self.x_block, **kwargs)\n\n\nclass FinalLayer(nn.Module):\n    \"\"\"\n    The final layer of DiT.\n    \"\"\"\n\n    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, total_out_channels: Optional[int] = None, dtype=None, device=None):\n        super().__init__()\n        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)\n        self.linear = (\n            nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)\n            if (total_out_channels is None)\n            else nn.Linear(hidden_size, total_out_channels, bias=True, dtype=dtype, device=device)\n        )\n        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))\n\n    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)\n        x = modulate(self.norm_final(x), shift, scale)\n        x = self.linear(x)\n        return x\n\n\nclass MMDiT(nn.Module):\n    \"\"\"Diffusion model with a Transformer backbone.\"\"\"\n\n    def __init__(\n        self,\n        input_size: int = 32,\n        patch_size: int = 2,\n        in_channels: int = 4,\n        depth: int = 28,\n        mlp_ratio: float = 4.0,\n        learn_sigma: bool = False,\n        adm_in_channels: Optional[int] = None,\n        context_embedder_config: Optional[Dict] = None,\n        register_length: int = 0,\n        attn_mode: str = \"torch\",\n        rmsnorm: bool = False,\n        scale_mod_only: bool = False,\n        swiglu: bool = False,\n        out_channels: Optional[int] = None,\n        pos_embed_scaling_factor: Optional[float] = None,\n        pos_embed_offset: Optional[float] = None,\n        pos_embed_max_size: Optional[int] = None,\n     
   num_patches = None,\n        qk_norm: Optional[str] = None,\n        qkv_bias: bool = True,\n        dtype = None,\n        device = None,\n    ):\n        super().__init__()\n        self.dtype = dtype\n        self.learn_sigma = learn_sigma\n        self.in_channels = in_channels\n        default_out_channels = in_channels * 2 if learn_sigma else in_channels\n        self.out_channels = out_channels if out_channels is not None else default_out_channels\n        self.patch_size = patch_size\n        self.pos_embed_scaling_factor = pos_embed_scaling_factor\n        self.pos_embed_offset = pos_embed_offset\n        self.pos_embed_max_size = pos_embed_max_size\n\n        # apply magic --> this defines a head_size of 64\n        hidden_size = 64 * depth\n        num_heads = depth\n\n        self.num_heads = num_heads\n\n        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True, strict_img_size=self.pos_embed_max_size is None, dtype=dtype, device=device)\n        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device)\n\n        if adm_in_channels is not None:\n            assert isinstance(adm_in_channels, int)\n            self.y_embedder = VectorEmbedder(adm_in_channels, hidden_size, dtype=dtype, device=device)\n\n        self.context_embedder = nn.Identity()\n        if context_embedder_config is not None:\n            if context_embedder_config[\"target\"] == \"torch.nn.Linear\":\n                self.context_embedder = nn.Linear(**context_embedder_config[\"params\"], dtype=dtype, device=device)\n\n        self.register_length = register_length\n        if self.register_length > 0:\n            self.register = nn.Parameter(torch.randn(1, register_length, hidden_size, dtype=dtype, device=device))\n\n        # num_patches = self.x_embedder.num_patches\n        # Will use fixed sin-cos embedding:\n        # just use a buffer already\n        if num_patches is not None:\n            self.register_buffer(\n  
              \"pos_embed\",\n                torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device),\n            )\n        else:\n            self.pos_embed = None\n\n        self.joint_blocks = nn.ModuleList(\n            [\n                JointBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, attn_mode=attn_mode, pre_only=i == depth - 1, rmsnorm=rmsnorm, scale_mod_only=scale_mod_only, swiglu=swiglu, qk_norm=qk_norm, dtype=dtype, device=device)\n                for i in range(depth)\n            ]\n        )\n\n        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels, dtype=dtype, device=device)\n\n    def cropped_pos_embed(self, hw):\n        assert self.pos_embed_max_size is not None\n        p = self.x_embedder.patch_size[0]\n        h, w = hw\n        # patched size\n        h = h // p\n        w = w // p\n        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)\n        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)\n        top = (self.pos_embed_max_size - h) // 2\n        left = (self.pos_embed_max_size - w) // 2\n        spatial_pos_embed = rearrange(\n            self.pos_embed,\n            \"1 (h w) c -> 1 h w c\",\n            h=self.pos_embed_max_size,\n            w=self.pos_embed_max_size,\n        )\n        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]\n        spatial_pos_embed = rearrange(spatial_pos_embed, \"1 h w c -> 1 (h w) c\")\n        return spatial_pos_embed\n\n    def unpatchify(self, x, hw=None):\n        \"\"\"\n        x: (N, T, patch_size**2 * C)\n        imgs: (N, H, W, C)\n        \"\"\"\n        c = self.out_channels\n        p = self.x_embedder.patch_size[0]\n        if hw is None:\n            h = w = int(x.shape[1] ** 0.5)\n        else:\n            h, w = hw\n            h = h // p\n            w = w // p\n        assert h * w == x.shape[1]\n\n        x = x.reshape(shape=(x.shape[0], h, w, p, p, 
c))\n        x = torch.einsum(\"nhwpqc->nchpwq\", x)\n        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))\n        return imgs\n\n    def forward_core_with_concat(self, x: torch.Tensor, c_mod: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:\n        if self.register_length > 0:\n            context = torch.cat((repeat(self.register, \"1 ... -> b ...\", b=x.shape[0]), context if context is not None else torch.Tensor([]).type_as(x)), 1)\n\n        # context is B, L', D\n        # x is B, L, D\n        for block in self.joint_blocks:\n            context, x = block(context, x, c=c_mod)\n\n        x = self.final_layer(x, c_mod)  # (N, T, patch_size ** 2 * out_channels)\n        return x\n\n    def forward(self, x: torch.Tensor, t: torch.Tensor, y: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None) -> torch.Tensor:\n        \"\"\"\n        Forward pass of DiT.\n        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)\n        t: (N,) tensor of diffusion timesteps\n        y: (N,) tensor of class labels\n        \"\"\"\n        hw = x.shape[-2:]\n        x = self.x_embedder(x) + self.cropped_pos_embed(hw)\n        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)\n        if y is not None:\n            y = self.y_embedder(y)  # (N, D)\n            c = c + y  # (N, D)\n\n        context = self.context_embedder(context)\n\n        x = self.forward_core_with_concat(x, c, context)\n\n        x = self.unpatchify(x, hw=hw)  # (N, out_channels, H, W)\n        return x\n"
  },
  {
    "path": "modules/models/sd3/other_impls.py",
    "content": "### This file contains impls for underlying related models (CLIP, T5, etc)\n\nimport torch\nimport math\nfrom torch import nn\nfrom transformers import CLIPTokenizer, T5TokenizerFast\n\nfrom modules import sd_hijack\n\n\n#################################################################################################\n### Core/Utility\n#################################################################################################\n\n\nclass AutocastLinear(nn.Linear):\n    \"\"\"Same as usual linear layer, but casts its weights to whatever the parameter type is.\n\n    This is different from torch.autocast in a way that float16 layer processing float32 input\n    will return float16 with autocast on, and float32 with this. T5 seems to be fucked\n    if you do it in full float16 (returning almost all zeros in the final output).\n    \"\"\"\n\n    def forward(self, x):\n        return torch.nn.functional.linear(x, self.weight.to(x.dtype), self.bias.to(x.dtype) if self.bias is not None else None)\n\n\ndef attention(q, k, v, heads, mask=None):\n    \"\"\"Convenience wrapper around a basic attention operation\"\"\"\n    b, _, dim_head = q.shape\n    dim_head //= heads\n    q, k, v = [t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v)]\n    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)\n    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)\n\n\nclass Mlp(nn.Module):\n    \"\"\" MLP as used in Vision Transformer, MLP-Mixer and related networks\"\"\"\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, dtype=None, device=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, dtype=dtype, device=device)\n        self.act = act_layer\n        self.fc2 = 
nn.Linear(hidden_features, out_features, bias=bias, dtype=dtype, device=device)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.fc2(x)\n        return x\n\n\n#################################################################################################\n### CLIP\n#################################################################################################\n\n\nclass CLIPAttention(torch.nn.Module):\n    def __init__(self, embed_dim, heads, dtype, device):\n        super().__init__()\n        self.heads = heads\n        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)\n        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)\n        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)\n\n    def forward(self, x, mask=None):\n        q = self.q_proj(x)\n        k = self.k_proj(x)\n        v = self.v_proj(x)\n        out = attention(q, k, v, self.heads, mask)\n        return self.out_proj(out)\n\n\nACTIVATIONS = {\n    \"quick_gelu\": lambda a: a * torch.sigmoid(1.702 * a),\n    \"gelu\": torch.nn.functional.gelu,\n}\n\nclass CLIPLayer(torch.nn.Module):\n    def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):\n        super().__init__()\n        self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)\n        self.self_attn = CLIPAttention(embed_dim, heads, dtype, device)\n        self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device)\n        #self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device)\n        self.mlp = Mlp(embed_dim, intermediate_size, embed_dim, act_layer=ACTIVATIONS[intermediate_activation], dtype=dtype, device=device)\n\n    def forward(self, x, mask=None):\n        x += 
self.self_attn(self.layer_norm1(x), mask)\n        x += self.mlp(self.layer_norm2(x))\n        return x\n\n\nclass CLIPEncoder(torch.nn.Module):\n    def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device):\n        super().__init__()\n        self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device) for i in range(num_layers)])\n\n    def forward(self, x, mask=None, intermediate_output=None):\n        if intermediate_output is not None:\n            if intermediate_output < 0:\n                intermediate_output = len(self.layers) + intermediate_output\n        intermediate = None\n        for i, layer in enumerate(self.layers):\n            x = layer(x, mask)\n            if i == intermediate_output:\n                intermediate = x.clone()\n        return x, intermediate\n\n\nclass CLIPEmbeddings(torch.nn.Module):\n    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, textual_inversion_key=\"clip_l\"):\n        super().__init__()\n        self.token_embedding = sd_hijack.TextualInversionEmbeddings(vocab_size, embed_dim, dtype=dtype, device=device, textual_inversion_key=textual_inversion_key)\n        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)\n\n    def forward(self, input_tokens):\n        return self.token_embedding(input_tokens) + self.position_embedding.weight\n\n\nclass CLIPTextModel_(torch.nn.Module):\n    def __init__(self, config_dict, dtype, device):\n        num_layers = config_dict[\"num_hidden_layers\"]\n        embed_dim = config_dict[\"hidden_size\"]\n        heads = config_dict[\"num_attention_heads\"]\n        intermediate_size = config_dict[\"intermediate_size\"]\n        intermediate_activation = config_dict[\"hidden_act\"]\n        super().__init__()\n        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, 
device=device, textual_inversion_key=config_dict.get('textual_inversion_key', 'clip_l'))\n        self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device)\n        self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device)\n\n    def forward(self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True):\n        x = self.embeddings(input_tokens)\n        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float(\"-inf\")).triu_(1)\n        x, i = self.encoder(x, mask=causal_mask, intermediate_output=intermediate_output)\n        x = self.final_layer_norm(x)\n        if i is not None and final_layer_norm_intermediate:\n            i = self.final_layer_norm(i)\n        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]\n        return x, i, pooled_output\n\n\nclass CLIPTextModel(torch.nn.Module):\n    def __init__(self, config_dict, dtype, device):\n        super().__init__()\n        self.num_layers = config_dict[\"num_hidden_layers\"]\n        self.text_model = CLIPTextModel_(config_dict, dtype, device)\n        embed_dim = config_dict[\"hidden_size\"]\n        self.text_projection = nn.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)\n        self.text_projection.weight.copy_(torch.eye(embed_dim))\n        self.dtype = dtype\n\n    def get_input_embeddings(self):\n        return self.text_model.embeddings.token_embedding\n\n    def set_input_embeddings(self, embeddings):\n        self.text_model.embeddings.token_embedding = embeddings\n\n    def forward(self, *args, **kwargs):\n        x = self.text_model(*args, **kwargs)\n        out = self.text_projection(x[2])\n        return (x[0], x[1], out, x[2])\n\n\nclass SDTokenizer:\n    def __init__(self, max_length=77, pad_with_end=True, tokenizer=None, has_start_token=True, 
pad_to_max_length=True, min_length=None):\n        self.tokenizer = tokenizer\n        self.max_length = max_length\n        self.min_length = min_length\n        empty = self.tokenizer('')[\"input_ids\"]\n        if has_start_token:\n            self.tokens_start = 1\n            self.start_token = empty[0]\n            self.end_token = empty[1]\n        else:\n            self.tokens_start = 0\n            self.start_token = None\n            self.end_token = empty[0]\n        self.pad_with_end = pad_with_end\n        self.pad_to_max_length = pad_to_max_length\n        vocab = self.tokenizer.get_vocab()\n        self.inv_vocab = {v: k for k, v in vocab.items()}\n        self.max_word_length = 8\n\n\n    def tokenize_with_weights(self, text:str):\n        \"\"\"Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.\"\"\"\n        if self.pad_with_end:\n            pad_token = self.end_token\n        else:\n            pad_token = 0\n        batch = []\n        if self.start_token is not None:\n            batch.append((self.start_token, 1.0))\n        to_tokenize = text.replace(\"\\n\", \" \").split(' ')\n        to_tokenize = [x for x in to_tokenize if x != \"\"]\n        for word in to_tokenize:\n            batch.extend([(t, 1) for t in self.tokenizer(word)[\"input_ids\"][self.tokens_start:-1]])\n        batch.append((self.end_token, 1.0))\n        if self.pad_to_max_length:\n            batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))\n        if self.min_length is not None and len(batch) < self.min_length:\n            batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))\n        return [batch]\n\n\nclass SDXLClipGTokenizer(SDTokenizer):\n    def __init__(self, tokenizer):\n        super().__init__(pad_with_end=False, tokenizer=tokenizer)\n\n\nclass SD3Tokenizer:\n    def __init__(self):\n        
clip_tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n        self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)\n        self.clip_g = SDXLClipGTokenizer(clip_tokenizer)\n        self.t5xxl = T5XXLTokenizer()\n\n    def tokenize_with_weights(self, text:str):\n        out = {}\n        out[\"g\"] = self.clip_g.tokenize_with_weights(text)\n        out[\"l\"] = self.clip_l.tokenize_with_weights(text)\n        out[\"t5xxl\"] = self.t5xxl.tokenize_with_weights(text)\n        return out\n\n\nclass ClipTokenWeightEncoder:\n    def encode_token_weights(self, token_weight_pairs):\n        tokens = [a[0] for a in token_weight_pairs[0]]\n        out, pooled = self([tokens])\n        if pooled is not None:\n            first_pooled = pooled[0:1].cpu()\n        else:\n            first_pooled = pooled\n        output = [out[0:1]]\n        return torch.cat(output, dim=-2).cpu(), first_pooled\n\n\nclass SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):\n    \"\"\"Uses the CLIP transformer encoder for text (from huggingface)\"\"\"\n    LAYERS = [\"last\", \"pooled\", \"hidden\"]\n    def __init__(self, device=\"cpu\", max_length=77, layer=\"last\", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=CLIPTextModel,\n                 special_tokens=None, layer_norm_hidden_state=True, return_projected_pooled=True):\n        super().__init__()\n        assert layer in self.LAYERS\n        self.transformer = model_class(textmodel_json_config, dtype, device)\n        self.num_layers = self.transformer.num_layers\n        self.max_length = max_length\n        self.transformer = self.transformer.eval()\n        for param in self.parameters():\n            param.requires_grad = False\n        self.layer = layer\n        self.layer_idx = None\n        self.special_tokens = special_tokens if special_tokens is not None else {\"start\": 49406, \"end\": 49407, \"pad\": 49407}\n        self.logit_scale = 
torch.nn.Parameter(torch.tensor(4.6055))\n        self.layer_norm_hidden_state = layer_norm_hidden_state\n        self.return_projected_pooled = return_projected_pooled\n        if layer == \"hidden\":\n            assert layer_idx is not None\n            assert abs(layer_idx) < self.num_layers\n            self.set_clip_options({\"layer\": layer_idx})\n        self.options_default = (self.layer, self.layer_idx, self.return_projected_pooled)\n\n    def set_clip_options(self, options):\n        layer_idx = options.get(\"layer\", self.layer_idx)\n        self.return_projected_pooled = options.get(\"projected_pooled\", self.return_projected_pooled)\n        if layer_idx is None or abs(layer_idx) > self.num_layers:\n            self.layer = \"last\"\n        else:\n            self.layer = \"hidden\"\n            self.layer_idx = layer_idx\n\n    def forward(self, tokens):\n        backup_embeds = self.transformer.get_input_embeddings()\n        tokens = torch.asarray(tokens, dtype=torch.int64, device=backup_embeds.weight.device)\n        outputs = self.transformer(tokens, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)\n        self.transformer.set_input_embeddings(backup_embeds)\n        if self.layer == \"last\":\n            z = outputs[0]\n        else:\n            z = outputs[1]\n        pooled_output = None\n        if len(outputs) >= 3:\n            if not self.return_projected_pooled and len(outputs) >= 4 and outputs[3] is not None:\n                pooled_output = outputs[3].float()\n            elif outputs[2] is not None:\n                pooled_output = outputs[2].float()\n        return z.float(), pooled_output\n\n\nclass SDXLClipG(SDClipModel):\n    \"\"\"Wraps the CLIP-G model into the SD-CLIP-Model interface\"\"\"\n    def __init__(self, config, device=\"cpu\", layer=\"penultimate\", layer_idx=None, dtype=None):\n        if layer == \"penultimate\":\n            layer=\"hidden\"\n            
layer_idx=-2\n        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={\"start\": 49406, \"end\": 49407, \"pad\": 0}, layer_norm_hidden_state=False)\n\n\nclass T5XXLModel(SDClipModel):\n    \"\"\"Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience\"\"\"\n    def __init__(self, config, device=\"cpu\", layer=\"last\", layer_idx=None, dtype=None):\n        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=config, dtype=dtype, special_tokens={\"end\": 1, \"pad\": 0}, model_class=T5)\n\n\n#################################################################################################\n### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl\n#################################################################################################\n\nclass T5XXLTokenizer(SDTokenizer):\n    \"\"\"Wraps the T5 Tokenizer from HF into the SDTokenizer interface\"\"\"\n    def __init__(self):\n        super().__init__(pad_with_end=False, tokenizer=T5TokenizerFast.from_pretrained(\"google/t5-v1_1-xxl\"), has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)\n\n\nclass T5LayerNorm(torch.nn.Module):\n    def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None):\n        super().__init__()\n        self.weight = torch.nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))\n        self.variance_epsilon = eps\n\n    def forward(self, x):\n        variance = x.pow(2).mean(-1, keepdim=True)\n        x = x * torch.rsqrt(variance + self.variance_epsilon)\n        return self.weight.to(device=x.device, dtype=x.dtype) * x\n\n\nclass T5DenseGatedActDense(torch.nn.Module):\n    def __init__(self, model_dim, ff_dim, dtype, device):\n        super().__init__()\n        self.wi_0 = AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)\n        self.wi_1 = 
AutocastLinear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)\n        self.wo = AutocastLinear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)\n\n    def forward(self, x):\n        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate=\"tanh\")\n        hidden_linear = self.wi_1(x)\n        x = hidden_gelu * hidden_linear\n        x = self.wo(x)\n        return x\n\n\nclass T5LayerFF(torch.nn.Module):\n    def __init__(self, model_dim, ff_dim, dtype, device):\n        super().__init__()\n        self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device)\n        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)\n\n    def forward(self, x):\n        forwarded_states = self.layer_norm(x)\n        forwarded_states = self.DenseReluDense(forwarded_states)\n        x += forwarded_states\n        return x\n\n\nclass T5Attention(torch.nn.Module):\n    def __init__(self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device):\n        super().__init__()\n        # Mesh TensorFlow initialization to avoid scaling before softmax\n        self.q = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)\n        self.k = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)\n        self.v = AutocastLinear(model_dim, inner_dim, bias=False, dtype=dtype, device=device)\n        self.o = AutocastLinear(inner_dim, model_dim, bias=False, dtype=dtype, device=device)\n        self.num_heads = num_heads\n        self.relative_attention_bias = None\n        if relative_attention_bias:\n            self.relative_attention_num_buckets = 32\n            self.relative_attention_max_distance = 128\n            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)\n\n    @staticmethod\n    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):\n       
 \"\"\"\n        Adapted from Mesh Tensorflow:\n        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593\n\n        Translate relative position to a bucket number for relative attention. The relative position is defined as\n        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to\n        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for\n        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative\n        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.\n        This should allow for more graceful generalization to longer sequences than the model has been trained on\n\n        Args:\n            relative_position: an int32 Tensor\n            bidirectional: a boolean - whether the attention is bidirectional\n            num_buckets: an integer\n            max_distance: an integer\n\n        Returns:\n            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)\n        \"\"\"\n        relative_buckets = 0\n        if bidirectional:\n            num_buckets //= 2\n            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets\n            relative_position = torch.abs(relative_position)\n        else:\n            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))\n        # now relative_position is in the range [0, inf)\n        # half of the buckets are for exact increments in positions\n        max_exact = num_buckets // 2\n        is_small = relative_position < max_exact\n        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance\n        relative_position_if_large = max_exact + (\n     
       torch.log(relative_position.float() / max_exact)\n            / math.log(max_distance / max_exact)\n            * (num_buckets - max_exact)\n        ).to(torch.long)\n        relative_position_if_large = torch.min(relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1))\n        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)\n        return relative_buckets\n\n    def compute_bias(self, query_length, key_length, device):\n        \"\"\"Compute binned relative position bias\"\"\"\n        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]\n        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]\n        relative_position = memory_position - context_position  # shape (query_length, key_length)\n        relative_position_bucket = self._relative_position_bucket(\n            relative_position,  # shape (query_length, key_length)\n            bidirectional=True,\n            num_buckets=self.relative_attention_num_buckets,\n            max_distance=self.relative_attention_max_distance,\n        )\n        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)\n        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)\n        return values\n\n    def forward(self, x, past_bias=None):\n        q = self.q(x)\n        k = self.k(x)\n        v = self.v(x)\n\n        if self.relative_attention_bias is not None:\n            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)\n        if past_bias is not None:\n            mask = past_bias\n        else:\n            mask = None\n\n        out = attention(q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask.to(x.dtype) if mask is not None else None)\n\n        return self.o(out), past_bias\n\n\nclass T5LayerSelfAttention(torch.nn.Module):\n 
   def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):\n        super().__init__()\n        self.SelfAttention = T5Attention(model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device)\n        self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)\n\n    def forward(self, x, past_bias=None):\n        output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias)\n        x += output\n        return x, past_bias\n\n\nclass T5Block(torch.nn.Module):\n    def __init__(self, model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device):\n        super().__init__()\n        self.layer = torch.nn.ModuleList()\n        self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device))\n        self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device))\n\n    def forward(self, x, past_bias=None):\n        x, past_bias = self.layer[0](x, past_bias)\n        x = self.layer[-1](x)\n        return x, past_bias\n\n\nclass T5Stack(torch.nn.Module):\n    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, num_heads, vocab_size, dtype, device):\n        super().__init__()\n        self.embed_tokens = torch.nn.Embedding(vocab_size, model_dim, device=device)\n        self.block = torch.nn.ModuleList([T5Block(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device) for i in range(num_layers)])\n        self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device)\n\n    def forward(self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True):\n        intermediate = None\n        x = self.embed_tokens(input_ids).to(torch.float32)  # needs float32 or else T5 returns all zeroes\n        past_bias = None\n        for i, layer in enumerate(self.block):\n            x, past_bias = layer(x, past_bias)\n            if i == 
intermediate_output:\n                intermediate = x.clone()\n        x = self.final_layer_norm(x)\n        if intermediate is not None and final_layer_norm_intermediate:\n            intermediate = self.final_layer_norm(intermediate)\n        return x, intermediate\n\n\nclass T5(torch.nn.Module):\n    def __init__(self, config_dict, dtype, device):\n        super().__init__()\n        self.num_layers = config_dict[\"num_layers\"]\n        self.encoder = T5Stack(self.num_layers, config_dict[\"d_model\"], config_dict[\"d_model\"], config_dict[\"d_ff\"], config_dict[\"num_heads\"], config_dict[\"vocab_size\"], dtype, device)\n        self.dtype = dtype\n\n    def get_input_embeddings(self):\n        return self.encoder.embed_tokens\n\n    def set_input_embeddings(self, embeddings):\n        self.encoder.embed_tokens = embeddings\n\n    def forward(self, *args, **kwargs):\n        return self.encoder(*args, **kwargs)\n"
  },
  {
    "path": "modules/models/sd3/sd3_cond.py",
    "content": "import os\r\nimport safetensors\r\nimport torch\r\nimport typing\r\n\r\nfrom transformers import CLIPTokenizer, T5TokenizerFast\r\n\r\nfrom modules import shared, devices, modelloader, sd_hijack_clip, prompt_parser\r\nfrom modules.models.sd3.other_impls import SDClipModel, SDXLClipG, T5XXLModel, SD3Tokenizer\r\n\r\n\r\nclass SafetensorsMapping(typing.Mapping):\r\n    def __init__(self, file):\r\n        self.file = file\r\n\r\n    def __len__(self):\r\n        return len(self.file.keys())\r\n\r\n    def __iter__(self):\r\n        for key in self.file.keys():\r\n            yield key\r\n\r\n    def __getitem__(self, key):\r\n        return self.file.get_tensor(key)\r\n\r\n\r\nCLIPL_URL = \"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_l.safetensors\"\r\nCLIPL_CONFIG = {\r\n    \"hidden_act\": \"quick_gelu\",\r\n    \"hidden_size\": 768,\r\n    \"intermediate_size\": 3072,\r\n    \"num_attention_heads\": 12,\r\n    \"num_hidden_layers\": 12,\r\n}\r\n\r\nCLIPG_URL = \"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/clip_g.safetensors\"\r\nCLIPG_CONFIG = {\r\n    \"hidden_act\": \"gelu\",\r\n    \"hidden_size\": 1280,\r\n    \"intermediate_size\": 5120,\r\n    \"num_attention_heads\": 20,\r\n    \"num_hidden_layers\": 32,\r\n    \"textual_inversion_key\": \"clip_g\",\r\n}\r\n\r\nT5_URL = \"https://huggingface.co/AUTOMATIC/stable-diffusion-3-medium-text-encoders/resolve/main/t5xxl_fp16.safetensors\"\r\nT5_CONFIG = {\r\n    \"d_ff\": 10240,\r\n    \"d_model\": 4096,\r\n    \"num_heads\": 64,\r\n    \"num_layers\": 24,\r\n    \"vocab_size\": 32128,\r\n}\r\n\r\n\r\nclass Sd3ClipLG(sd_hijack_clip.TextConditionalModel):\r\n    def __init__(self, clip_l, clip_g):\r\n        super().__init__()\r\n\r\n        self.clip_l = clip_l\r\n        self.clip_g = clip_g\r\n\r\n        self.tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\r\n\r\n        empty = 
self.tokenizer('')[\"input_ids\"]\r\n        self.id_start = empty[0]\r\n        self.id_end = empty[1]\r\n        self.id_pad = empty[1]\r\n\r\n        self.return_pooled = True\r\n\r\n    def tokenize(self, texts):\r\n        return self.tokenizer(texts, truncation=False, add_special_tokens=False)[\"input_ids\"]\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        tokens_g = tokens.clone()\r\n\r\n        for batch_pos in range(tokens_g.shape[0]):\r\n            index = tokens_g[batch_pos].cpu().tolist().index(self.id_end)\r\n            tokens_g[batch_pos, index+1:tokens_g.shape[1]] = 0\r\n\r\n        l_out, l_pooled = self.clip_l(tokens)\r\n        g_out, g_pooled = self.clip_g(tokens_g)\r\n\r\n        lg_out = torch.cat([l_out, g_out], dim=-1)\r\n        lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))\r\n\r\n        vector_out = torch.cat((l_pooled, g_pooled), dim=-1)\r\n\r\n        lg_out.pooled = vector_out\r\n        return lg_out\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        return torch.zeros((nvpt, 768+1280), device=devices.device) # XXX\r\n\r\n\r\nclass Sd3T5(torch.nn.Module):\r\n    def __init__(self, t5xxl):\r\n        super().__init__()\r\n\r\n        self.t5xxl = t5xxl\r\n        self.tokenizer = T5TokenizerFast.from_pretrained(\"google/t5-v1_1-xxl\")\r\n\r\n        empty = self.tokenizer('', padding='max_length', max_length=2)[\"input_ids\"]\r\n        self.id_end = empty[0]\r\n        self.id_pad = empty[1]\r\n\r\n    def tokenize(self, texts):\r\n        return self.tokenizer(texts, truncation=False, add_special_tokens=False)[\"input_ids\"]\r\n\r\n    def tokenize_line(self, line, *, target_token_count=None):\r\n        if shared.opts.emphasis != \"None\":\r\n            parsed = prompt_parser.parse_prompt_attention(line)\r\n        else:\r\n            parsed = [[line, 1.0]]\r\n\r\n        tokenized = self.tokenize([text for text, _ in parsed])\r\n\r\n        tokens = []\r\n       
 multipliers = []\r\n\r\n        for text_tokens, (text, weight) in zip(tokenized, parsed):\r\n            if text == 'BREAK' and weight == -1:\r\n                continue\r\n\r\n            tokens += text_tokens\r\n            multipliers += [weight] * len(text_tokens)\r\n\r\n        tokens += [self.id_end]\r\n        multipliers += [1.0]\r\n\r\n        if target_token_count is not None:\r\n            if len(tokens) < target_token_count:\r\n                tokens += [self.id_pad] * (target_token_count - len(tokens))\r\n                multipliers += [1.0] * (target_token_count - len(tokens))\r\n            else:\r\n                tokens = tokens[0:target_token_count]\r\n                multipliers = multipliers[0:target_token_count]\r\n\r\n        return tokens, multipliers\r\n\r\n    def forward(self, texts, *, token_count):\r\n        if not self.t5xxl or not shared.opts.sd3_enable_t5:\r\n            return torch.zeros((len(texts), token_count, 4096), device=devices.device, dtype=devices.dtype)\r\n\r\n        tokens_batch = []\r\n\r\n        for text in texts:\r\n            tokens, multipliers = self.tokenize_line(text, target_token_count=token_count)\r\n            tokens_batch.append(tokens)\r\n\r\n        t5_out, t5_pooled = self.t5xxl(tokens_batch)\r\n\r\n        return t5_out\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        return torch.zeros((nvpt, 4096), device=devices.device) # XXX\r\n\r\n\r\nclass SD3Cond(torch.nn.Module):\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n\r\n        self.tokenizer = SD3Tokenizer()\r\n\r\n        with torch.no_grad():\r\n            self.clip_g = SDXLClipG(CLIPG_CONFIG, device=\"cpu\", dtype=devices.dtype)\r\n            self.clip_l = SDClipModel(layer=\"hidden\", layer_idx=-2, device=\"cpu\", dtype=devices.dtype, layer_norm_hidden_state=False, return_projected_pooled=False, textmodel_json_config=CLIPL_CONFIG)\r\n\r\n            if 
shared.opts.sd3_enable_t5:\r\n                self.t5xxl = T5XXLModel(T5_CONFIG, device=\"cpu\", dtype=devices.dtype)\r\n            else:\r\n                self.t5xxl = None\r\n\r\n            self.model_lg = Sd3ClipLG(self.clip_l, self.clip_g)\r\n            self.model_t5 = Sd3T5(self.t5xxl)\r\n\r\n    def forward(self, prompts: list[str]):\r\n        with devices.without_autocast():\r\n            lg_out, vector_out = self.model_lg(prompts)\r\n            t5_out = self.model_t5(prompts, token_count=lg_out.shape[1])\r\n            lgt_out = torch.cat([lg_out, t5_out], dim=-2)\r\n\r\n        return {\r\n            'crossattn': lgt_out,\r\n            'vector': vector_out,\r\n        }\r\n\r\n    def before_load_weights(self, state_dict):\r\n        clip_path = os.path.join(shared.models_path, \"CLIP\")\r\n\r\n        if 'text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:\r\n            clip_g_file = modelloader.load_file_from_url(CLIPG_URL, model_dir=clip_path, file_name=\"clip_g.safetensors\")\r\n            with safetensors.safe_open(clip_g_file, framework=\"pt\") as file:\r\n                self.clip_g.transformer.load_state_dict(SafetensorsMapping(file))\r\n\r\n        if 'text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight' not in state_dict:\r\n            clip_l_file = modelloader.load_file_from_url(CLIPL_URL, model_dir=clip_path, file_name=\"clip_l.safetensors\")\r\n            with safetensors.safe_open(clip_l_file, framework=\"pt\") as file:\r\n                self.clip_l.transformer.load_state_dict(SafetensorsMapping(file), strict=False)\r\n\r\n        if self.t5xxl and 'text_encoders.t5xxl.transformer.encoder.embed_tokens.weight' not in state_dict:\r\n            t5_file = modelloader.load_file_from_url(T5_URL, model_dir=clip_path, file_name=\"t5xxl_fp16.safetensors\")\r\n            with safetensors.safe_open(t5_file, framework=\"pt\") as file:\r\n                
self.t5xxl.transformer.load_state_dict(SafetensorsMapping(file), strict=False)\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        return self.model_lg.encode_embedding_init_text(init_text, nvpt)\r\n\r\n    def tokenize(self, texts):\r\n        return self.model_lg.tokenize(texts)\r\n\r\n    def medvram_modules(self):\r\n        return [self.clip_g, self.clip_l, self.t5xxl]\r\n\r\n    def get_token_count(self, text):\r\n        _, token_count = self.model_lg.process_texts([text])\r\n\r\n        return token_count\r\n\r\n    def get_target_prompt_token_count(self, token_count):\r\n        return self.model_lg.get_target_prompt_token_count(token_count)\r\n"
  },
  {
    "path": "modules/models/sd3/sd3_impls.py",
    "content": "### Impls of the SD3 core diffusion model and VAE\n\nimport torch\nimport math\nimport einops\nfrom modules.models.sd3.mmdit import MMDiT\nfrom PIL import Image\n\n\n#################################################################################################\n### MMDiT Model Wrapping\n#################################################################################################\n\n\nclass ModelSamplingDiscreteFlow(torch.nn.Module):\n    \"\"\"Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models\"\"\"\n    def __init__(self, shift=1.0):\n        super().__init__()\n        self.shift = shift\n        timesteps = 1000\n        ts = self.sigma(torch.arange(1, timesteps + 1, 1))\n        self.register_buffer('sigmas', ts)\n\n    @property\n    def sigma_min(self):\n        return self.sigmas[0]\n\n    @property\n    def sigma_max(self):\n        return self.sigmas[-1]\n\n    def timestep(self, sigma):\n        return sigma * 1000\n\n    def sigma(self, timestep: torch.Tensor):\n        timestep = timestep / 1000.0\n        if self.shift == 1.0:\n            return timestep\n        return self.shift * timestep / (1 + (self.shift - 1) * timestep)\n\n    def calculate_denoised(self, sigma, model_output, model_input):\n        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))\n        return model_input - model_output * sigma\n\n    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):\n        return sigma * noise + (1.0 - sigma) * latent_image\n\n\nclass BaseModel(torch.nn.Module):\n    \"\"\"Wrapper around the core MM-DiT model\"\"\"\n    def __init__(self, shift=1.0, device=None, dtype=torch.float32, state_dict=None, prefix=\"\"):\n        super().__init__()\n        # Important configuration values can be quickly determined by checking shapes in the source file\n        # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other 
details change)\n        patch_size = state_dict[f\"{prefix}x_embedder.proj.weight\"].shape[2]\n        depth = state_dict[f\"{prefix}x_embedder.proj.weight\"].shape[0] // 64\n        num_patches = state_dict[f\"{prefix}pos_embed\"].shape[1]\n        pos_embed_max_size = round(math.sqrt(num_patches))\n        adm_in_channels = state_dict[f\"{prefix}y_embedder.mlp.0.weight\"].shape[1]\n        context_shape = state_dict[f\"{prefix}context_embedder.weight\"].shape\n        context_embedder_config = {\n            \"target\": \"torch.nn.Linear\",\n            \"params\": {\n                \"in_features\": context_shape[1],\n                \"out_features\": context_shape[0]\n            }\n        }\n        self.diffusion_model = MMDiT(input_size=None, pos_embed_scaling_factor=None, pos_embed_offset=None, pos_embed_max_size=pos_embed_max_size, patch_size=patch_size, in_channels=16, depth=depth, num_patches=num_patches, adm_in_channels=adm_in_channels, context_embedder_config=context_embedder_config, device=device, dtype=dtype)\n        self.model_sampling = ModelSamplingDiscreteFlow(shift=shift)\n        self.depth = depth\n\n    def apply_model(self, x, sigma, c_crossattn=None, y=None):\n        dtype = self.get_dtype()\n        timestep = self.model_sampling.timestep(sigma).float()\n        model_output = self.diffusion_model(x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype)).float()\n        return self.model_sampling.calculate_denoised(sigma, model_output, x)\n\n    def forward(self, *args, **kwargs):\n        return self.apply_model(*args, **kwargs)\n\n    def get_dtype(self):\n        return self.diffusion_model.dtype\n\n\nclass CFGDenoiser(torch.nn.Module):\n    \"\"\"Helper for applying CFG Scaling to diffusion outputs\"\"\"\n    def __init__(self, model):\n        super().__init__()\n        self.model = model\n\n    def forward(self, x, timestep, cond, uncond, cond_scale):\n        # Run cond and uncond in a batch together\n        
batched = self.model.apply_model(torch.cat([x, x]), torch.cat([timestep, timestep]), c_crossattn=torch.cat([cond[\"c_crossattn\"], uncond[\"c_crossattn\"]]), y=torch.cat([cond[\"y\"], uncond[\"y\"]]))\n        # Then split and apply CFG Scaling\n        pos_out, neg_out = batched.chunk(2)\n        scaled = neg_out + (pos_out - neg_out) * cond_scale\n        return scaled\n\n\nclass SD3LatentFormat:\n    \"\"\"Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift\"\"\"\n    def __init__(self):\n        self.scale_factor = 1.5305\n        self.shift_factor = 0.0609\n\n    def process_in(self, latent):\n        return (latent - self.shift_factor) * self.scale_factor\n\n    def process_out(self, latent):\n        return (latent / self.scale_factor) + self.shift_factor\n\n    def decode_latent_to_preview(self, x0):\n        \"\"\"Quick RGB approximate preview of sd3 latents\"\"\"\n        factors = torch.tensor([\n            [-0.0645,  0.0177,  0.1052], [ 0.0028,  0.0312,  0.0650],\n            [ 0.1848,  0.0762,  0.0360], [ 0.0944,  0.0360,  0.0889],\n            [ 0.0897,  0.0506, -0.0364], [-0.0020,  0.1203,  0.0284],\n            [ 0.0855,  0.0118,  0.0283], [-0.0539,  0.0658,  0.1047],\n            [-0.0057,  0.0116,  0.0700], [-0.0412,  0.0281, -0.0039],\n            [ 0.1106,  0.1171,  0.1220], [-0.0248,  0.0682, -0.0481],\n            [ 0.0815,  0.0846,  0.1207], [-0.0120, -0.0055, -0.0867],\n            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259]\n        ], device=\"cpu\")\n        latent_image = x0[0].permute(1, 2, 0).cpu() @ factors\n\n        latents_ubyte = (((latent_image + 1) / 2)\n                            .clamp(0, 1)  # change scale from -1..1 to 0..1\n                            .mul(0xFF)  # to 0..255\n                            .byte()).cpu()\n\n        return 
Image.fromarray(latents_ubyte.numpy())\n\n\n#################################################################################################\n### K-Diffusion Sampling\n#################################################################################################\n\n\ndef append_dims(x, target_dims):\n    \"\"\"Appends dimensions to the end of a tensor until it has target_dims dimensions.\"\"\"\n    dims_to_append = target_dims - x.ndim\n    return x[(...,) + (None,) * dims_to_append]\n\n\ndef to_d(x, sigma, denoised):\n    \"\"\"Converts a denoiser output to a Karras ODE derivative.\"\"\"\n    return (x - denoised) / append_dims(sigma, x.ndim)\n\n\n@torch.no_grad()\n@torch.autocast(\"cuda\", dtype=torch.float16)\ndef sample_euler(model, x, sigmas, extra_args=None):\n    \"\"\"Implements Algorithm 2 (Euler steps) from Karras et al. (2022).\"\"\"\n    extra_args = {} if extra_args is None else extra_args\n    s_in = x.new_ones([x.shape[0]])\n    for i in range(len(sigmas) - 1):\n        sigma_hat = sigmas[i]\n        denoised = model(x, sigma_hat * s_in, **extra_args)\n        d = to_d(x, sigma_hat, denoised)\n        dt = sigmas[i + 1] - sigma_hat\n        # Euler method\n        x = x + d * dt\n    return x\n\n\n#################################################################################################\n### VAE\n#################################################################################################\n\n\ndef Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):\n    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)\n\n\nclass ResnetBlock(torch.nn.Module):\n    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n\n        self.norm1 
= Normalize(in_channels, dtype=dtype, device=device)\n        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        self.norm2 = Normalize(out_channels, dtype=dtype, device=device)\n        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        if self.in_channels != self.out_channels:\n            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)\n        else:\n            self.nin_shortcut = None\n        self.swish = torch.nn.SiLU(inplace=True)\n\n    def forward(self, x):\n        hidden = x\n        hidden = self.norm1(hidden)\n        hidden = self.swish(hidden)\n        hidden = self.conv1(hidden)\n        hidden = self.norm2(hidden)\n        hidden = self.swish(hidden)\n        hidden = self.conv2(hidden)\n        if self.in_channels != self.out_channels:\n            x = self.nin_shortcut(x)\n        return x + hidden\n\n\nclass AttnBlock(torch.nn.Module):\n    def __init__(self, in_channels, dtype=torch.float32, device=None):\n        super().__init__()\n        self.norm = Normalize(in_channels, dtype=dtype, device=device)\n        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)\n        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)\n        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)\n        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)\n\n    def forward(self, x):\n        hidden = self.norm(x)\n        q = self.q(hidden)\n        k = self.k(hidden)\n        v = self.v(hidden)\n        b, c, h, w = q.shape\n        q, k, v = [einops.rearrange(x, \"b 
c h w -> b 1 (h w) c\").contiguous() for x in (q, k, v)]\n        hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default\n        hidden = einops.rearrange(hidden, \"b 1 (h w) c -> b c h w\", h=h, w=w, c=c, b=b)\n        hidden = self.proj_out(hidden)\n        return x + hidden\n\n\nclass Downsample(torch.nn.Module):\n    def __init__(self, in_channels, dtype=torch.float32, device=None):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)\n\n    def forward(self, x):\n        pad = (0,1,0,1)\n        x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n        x = self.conv(x)\n        return x\n\n\nclass Upsample(torch.nn.Module):\n    def __init__(self, in_channels, dtype=torch.float32, device=None):\n        super().__init__()\n        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        x = self.conv(x)\n        return x\n\n\nclass VAEEncoder(torch.nn.Module):\n    def __init__(self, ch=128, ch_mult=(1,2,4,4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None):\n        super().__init__()\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = torch.nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = torch.nn.ModuleList()\n            attn = torch.nn.ModuleList()\n            block_in = ch*in_ch_mult[i_level]\n            block_out = ch*ch_mult[i_level]\n            for _ 
in range(num_res_blocks):\n                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))\n                block_in = block_out\n            down = torch.nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                down.downsample = Downsample(block_in, dtype=dtype, device=device)\n            self.down.append(down)\n        # middle\n        self.mid = torch.nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)\n        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)\n        # end\n        self.norm_out = Normalize(block_in, dtype=dtype, device=device)\n        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        self.swish = torch.nn.SiLU(inplace=True)\n\n    def forward(self, x):\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1])\n                hs.append(h)\n            if i_level != self.num_resolutions-1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h)\n        # end\n        h = self.norm_out(h)\n        h = self.swish(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass VAEDecoder(torch.nn.Module):\n    def __init__(self, ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2, resolution=256, z_channels=16, dtype=torch.float32, device=None):\n        super().__init__()\n        self.num_resolutions = 
len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        block_in = ch * ch_mult[self.num_resolutions - 1]\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        # middle\n        self.mid = torch.nn.Module()\n        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)\n        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)\n        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)\n        # upsampling\n        self.up = torch.nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = torch.nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            for _ in range(self.num_res_blocks + 1):\n                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))\n                block_in = block_out\n            up = torch.nn.Module()\n            up.block = block\n            if i_level != 0:\n                up.upsample = Upsample(block_in, dtype=dtype, device=device)\n                curr_res = curr_res * 2\n            self.up.insert(0, up) # prepend to get consistent order\n        # end\n        self.norm_out = Normalize(block_in, dtype=dtype, device=device)\n        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)\n        self.swish = torch.nn.SiLU(inplace=True)\n\n    def forward(self, z):\n        # z to block_in\n        hidden = self.conv_in(z)\n        # middle\n        hidden = self.mid.block_1(hidden)\n        hidden = self.mid.attn_1(hidden)\n        hidden = self.mid.block_2(hidden)\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for 
i_block in range(self.num_res_blocks + 1):\n                hidden = self.up[i_level].block[i_block](hidden)\n            if i_level != 0:\n                hidden = self.up[i_level].upsample(hidden)\n        # end\n        hidden = self.norm_out(hidden)\n        hidden = self.swish(hidden)\n        hidden = self.conv_out(hidden)\n        return hidden\n\n\nclass SDVAE(torch.nn.Module):\n    def __init__(self, dtype=torch.float32, device=None):\n        super().__init__()\n        self.encoder = VAEEncoder(dtype=dtype, device=device)\n        self.decoder = VAEDecoder(dtype=dtype, device=device)\n\n    @torch.autocast(\"cuda\", dtype=torch.float16)\n    def decode(self, latent):\n        return self.decoder(latent)\n\n    @torch.autocast(\"cuda\", dtype=torch.float16)\n    def encode(self, image):\n        hidden = self.encoder(image)\n        mean, logvar = torch.chunk(hidden, 2, dim=1)\n        logvar = torch.clamp(logvar, -30.0, 20.0)\n        std = torch.exp(0.5 * logvar)\n        return mean + std * torch.randn_like(mean)\n"
  },
  {
    "path": "modules/models/sd3/sd3_model.py",
    "content": "import contextlib\r\n\r\nimport torch\r\n\r\nimport k_diffusion\r\nfrom modules.models.sd3.sd3_impls import BaseModel, SDVAE, SD3LatentFormat\r\nfrom modules.models.sd3.sd3_cond import SD3Cond\r\n\r\nfrom modules import shared, devices\r\n\r\n\r\nclass SD3Denoiser(k_diffusion.external.DiscreteSchedule):\r\n    def __init__(self, inner_model, sigmas):\r\n        super().__init__(sigmas, quantize=shared.opts.enable_quantization)\r\n        self.inner_model = inner_model\r\n\r\n    def forward(self, input, sigma, **kwargs):\r\n        return self.inner_model.apply_model(input, sigma, **kwargs)\r\n\r\n\r\nclass SD3Inferencer(torch.nn.Module):\r\n    def __init__(self, state_dict, shift=3, use_ema=False):\r\n        super().__init__()\r\n\r\n        self.shift = shift\r\n\r\n        with torch.no_grad():\r\n            self.model = BaseModel(shift=shift, state_dict=state_dict, prefix=\"model.diffusion_model.\", device=\"cpu\", dtype=devices.dtype)\r\n            self.first_stage_model = SDVAE(device=\"cpu\", dtype=devices.dtype_vae)\r\n            self.first_stage_model.dtype = self.model.diffusion_model.dtype\r\n\r\n        self.alphas_cumprod = 1 / (self.model.model_sampling.sigmas ** 2 + 1)\r\n\r\n        self.text_encoders = SD3Cond()\r\n        self.cond_stage_key = 'txt'\r\n\r\n        self.parameterization = \"eps\"\r\n        self.model.conditioning_key = \"crossattn\"\r\n\r\n        self.latent_format = SD3LatentFormat()\r\n        self.latent_channels = 16\r\n\r\n    @property\r\n    def cond_stage_model(self):\r\n        return self.text_encoders\r\n\r\n    def before_load_weights(self, state_dict):\r\n        self.cond_stage_model.before_load_weights(state_dict)\r\n\r\n    def ema_scope(self):\r\n        return contextlib.nullcontext()\r\n\r\n    def get_learned_conditioning(self, batch: list[str]):\r\n        return self.cond_stage_model(batch)\r\n\r\n    def apply_model(self, x, t, cond):\r\n        return self.model(x, t, 
c_crossattn=cond['crossattn'], y=cond['vector'])\r\n\r\n    def decode_first_stage(self, latent):\r\n        latent = self.latent_format.process_out(latent)\r\n        return self.first_stage_model.decode(latent)\r\n\r\n    def encode_first_stage(self, image):\r\n        latent = self.first_stage_model.encode(image)\r\n        return self.latent_format.process_in(latent)\r\n\r\n    def get_first_stage_encoding(self, x):\r\n        return x\r\n\r\n    def create_denoiser(self):\r\n        return SD3Denoiser(self, self.model.model_sampling.sigmas)\r\n\r\n    def medvram_fields(self):\r\n        return [\r\n            (self, 'first_stage_model'),\r\n            (self, 'text_encoders'),\r\n            (self, 'model'),\r\n        ]\r\n\r\n    def add_noise_to_latent(self, x, noise, amount):\r\n        return x * (1 - amount) + noise * amount\r\n\r\n    def fix_dimensions(self, width, height):\r\n        return width // 16 * 16, height // 16 * 16\r\n\r\n    def diffusers_weight_mapping(self):\r\n        for i in range(self.model.depth):\r\n            yield f\"transformer.transformer_blocks.{i}.attn.to_q\", f\"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_q_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.to_k\", f\"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_k_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.to_v\", f\"diffusion_model_joint_blocks_{i}_x_block_attn_qkv_v_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.to_out.0\", f\"diffusion_model_joint_blocks_{i}_x_block_attn_proj\"\r\n\r\n            yield f\"transformer.transformer_blocks.{i}.attn.add_q_proj\", f\"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_q_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.add_k_proj\", f\"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_k_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.add_v_proj\", 
f\"diffusion_model_joint_blocks_{i}_context_block.attn_qkv_v_proj\"\r\n            yield f\"transformer.transformer_blocks.{i}.attn.add_out_proj.0\", f\"diffusion_model_joint_blocks_{i}_context_block_attn_proj\"\r\n"
  },
  {
    "path": "modules/ngrok.py",
    "content": "import ngrok\n\n# Connect to ngrok for ingress\ndef connect(token, port, options):\n    account = None\n    if token is None:\n        token = 'None'\n    else:\n        if ':' in token:\n            # token = authtoken:username:password\n            token, username, password = token.split(':', 2)\n            account = f\"{username}:{password}\"\n\n    # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py\n    if not options.get('authtoken_from_env'):\n        options['authtoken'] = token\n    if account:\n        options['basic_auth'] = account\n    if not options.get('session_metadata'):\n        options['session_metadata'] = 'stable-diffusion-webui'\n\n\n    try:\n        public_url = ngrok.connect(f\"127.0.0.1:{port}\", **options).url()\n    except Exception as e:\n        print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\\n'\n              f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken')\n    else:\n        print(f'ngrok connected to localhost:{port}! URL: {public_url}\\n'\n               'You can use this link after the launch is complete.')\n"
  },
  {
    "path": "modules/npu_specific.py",
    "content": "import importlib\nimport torch\n\nfrom modules import shared\n\n\ndef check_for_npu():\n    if importlib.util.find_spec(\"torch_npu\") is None:\n        return False\n    import torch_npu\n\n    try:\n        # Will raise a RuntimeError if no NPU is found\n        _ = torch_npu.npu.device_count()\n        return torch.npu.is_available()\n    except RuntimeError:\n        return False\n\n\ndef get_npu_device_string():\n    if shared.cmd_opts.device_id is not None:\n        return f\"npu:{shared.cmd_opts.device_id}\"\n    return \"npu:0\"\n\n\ndef torch_npu_gc():\n    with torch.npu.device(get_npu_device_string()):\n        torch.npu.empty_cache()\n\n\nhas_npu = check_for_npu()\n"
  },
  {
    "path": "modules/options.py",
    "content": "import os\r\nimport json\r\nimport sys\r\nfrom dataclasses import dataclass\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import errors\r\nfrom modules.shared_cmd_options import cmd_opts\r\nfrom modules.paths_internal import script_path\r\n\r\n\r\nclass OptionInfo:\r\n    def __init__(self, default=None, label=\"\", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False, category_id=None):\r\n        self.default = default\r\n        self.label = label\r\n        self.component = component\r\n        self.component_args = component_args\r\n        self.onchange = onchange\r\n        self.section = section\r\n        self.category_id = category_id\r\n        self.refresh = refresh\r\n        self.do_not_save = False\r\n\r\n        self.comment_before = comment_before\r\n        \"\"\"HTML text that will be added after label in UI\"\"\"\r\n\r\n        self.comment_after = comment_after\r\n        \"\"\"HTML text that will be added before label in UI\"\"\"\r\n\r\n        self.infotext = infotext\r\n\r\n        self.restrict_api = restrict_api\r\n        \"\"\"If True, the setting will not be accessible via API\"\"\"\r\n\r\n    def link(self, label, url):\r\n        self.comment_before += f\"[<a href='{url}' target='_blank'>{label}</a>]\"\r\n        return self\r\n\r\n    def js(self, label, js_func):\r\n        self.comment_before += f\"[<a onclick='{js_func}(); return false'>{label}</a>]\"\r\n        return self\r\n\r\n    def info(self, info):\r\n        self.comment_after += f\"<span class='info'>({info})</span>\"\r\n        return self\r\n\r\n    def html(self, html):\r\n        self.comment_after += html\r\n        return self\r\n\r\n    def needs_restart(self):\r\n        self.comment_after += \" <span class='info'>(requires restart)</span>\"\r\n        return self\r\n\r\n    def needs_reload_ui(self):\r\n        self.comment_after += \" <span 
class='info'>(requires Reload UI)</span>\"\r\n        return self\r\n\r\n\r\nclass OptionHTML(OptionInfo):\r\n    def __init__(self, text):\r\n        super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes=\"settings-info\", **kwargs))\r\n\r\n        self.do_not_save = True\r\n\r\n\r\ndef options_section(section_identifier, options_dict):\r\n    for v in options_dict.values():\r\n        if len(section_identifier) == 2:\r\n            v.section = section_identifier\r\n        elif len(section_identifier) == 3:\r\n            v.section = section_identifier[0:2]\r\n            v.category_id = section_identifier[2]\r\n\r\n    return options_dict\r\n\r\n\r\noptions_builtin_fields = {\"data_labels\", \"data\", \"restricted_opts\", \"typemap\"}\r\n\r\n\r\nclass Options:\r\n    typemap = {int: float}\r\n\r\n    def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts):\r\n        self.data_labels = data_labels\r\n        self.data = {k: v.default for k, v in self.data_labels.items() if not v.do_not_save}\r\n        self.restricted_opts = restricted_opts\r\n\r\n    def __setattr__(self, key, value):\r\n        if key in options_builtin_fields:\r\n            return super(Options, self).__setattr__(key, value)\r\n\r\n        if self.data is not None:\r\n            if key in self.data or key in self.data_labels:\r\n\r\n                # Check that settings aren't globally frozen\r\n                assert not cmd_opts.freeze_settings, \"changing settings is disabled\"\r\n\r\n                # Get the info related to the setting being changed\r\n                info = self.data_labels.get(key, None)\r\n                if info.do_not_save:\r\n                    return\r\n\r\n                # Restrict component arguments\r\n                comp_args = info.component_args if info else None\r\n                if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:\r\n                    raise 
RuntimeError(f\"not possible to set '{key}' because it is restricted\")\r\n\r\n                # Check that this section isn't frozen\r\n                if cmd_opts.freeze_settings_in_sections is not None:\r\n                    frozen_sections = list(map(str.strip, cmd_opts.freeze_settings_in_sections.split(','))) # Trim whitespace from section names\r\n                    section_key = info.section[0]\r\n                    section_name = info.section[1]\r\n                    assert section_key not in frozen_sections, f\"not possible to set '{key}' because settings in section '{section_name}' ({section_key}) are frozen with --freeze-settings-in-sections\"\r\n\r\n                # Check that this section of the settings isn't frozen\r\n                if cmd_opts.freeze_specific_settings is not None:\r\n                    frozen_keys = list(map(str.strip, cmd_opts.freeze_specific_settings.split(','))) # Trim whitespace from setting keys\r\n                    assert key not in frozen_keys, f\"not possible to set '{key}' because this setting is frozen with --freeze-specific-settings\"\r\n\r\n                # Check shorthand option which disables editing options in \"saving-paths\"\r\n                if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:\r\n                    raise RuntimeError(f\"not possible to set '{key}' because it is restricted with --hide_ui_dir_config\")\r\n\r\n                self.data[key] = value\r\n                return\r\n\r\n        return super(Options, self).__setattr__(key, value)\r\n\r\n    def __getattr__(self, item):\r\n        if item in options_builtin_fields:\r\n            return super(Options, self).__getattribute__(item)\r\n\r\n        if self.data is not None:\r\n            if item in self.data:\r\n                return self.data[item]\r\n\r\n        if item in self.data_labels:\r\n            return self.data_labels[item].default\r\n\r\n        return super(Options, self).__getattribute__(item)\r\n\r\n    
def set(self, key, value, is_api=False, run_callbacks=True):\r\n        \"\"\"sets an option and calls its onchange callback, returning True if the option changed and False otherwise\"\"\"\r\n\r\n        oldval = self.data.get(key, None)\r\n        if oldval == value:\r\n            return False\r\n\r\n        option = self.data_labels[key]\r\n        if option.do_not_save:\r\n            return False\r\n\r\n        if is_api and option.restrict_api:\r\n            return False\r\n\r\n        try:\r\n            setattr(self, key, value)\r\n        except RuntimeError:\r\n            return False\r\n\r\n        if run_callbacks and option.onchange is not None:\r\n            try:\r\n                option.onchange()\r\n            except Exception as e:\r\n                errors.display(e, f\"changing setting {key} to {value}\")\r\n                setattr(self, key, oldval)\r\n                return False\r\n\r\n        return True\r\n\r\n    def get_default(self, key):\r\n        \"\"\"returns the default value for the key\"\"\"\r\n\r\n        data_label = self.data_labels.get(key)\r\n        if data_label is None:\r\n            return None\r\n\r\n        return data_label.default\r\n\r\n    def save(self, filename):\r\n        assert not cmd_opts.freeze_settings, \"saving settings is disabled\"\r\n\r\n        with open(filename, \"w\", encoding=\"utf8\") as file:\r\n            json.dump(self.data, file, indent=4, ensure_ascii=False)\r\n\r\n    def same_type(self, x, y):\r\n        if x is None or y is None:\r\n            return True\r\n\r\n        type_x = self.typemap.get(type(x), type(x))\r\n        type_y = self.typemap.get(type(y), type(y))\r\n\r\n        return type_x == type_y\r\n\r\n    def load(self, filename):\r\n        try:\r\n            with open(filename, \"r\", encoding=\"utf8\") as file:\r\n                self.data = json.load(file)\r\n        except FileNotFoundError:\r\n            self.data = {}\r\n        except Exception:\r\n            
errors.report(f'\\nCould not load settings\\nThe config file \"{filename}\" is likely corrupted\\nIt has been moved to the \"tmp/config.json\"\\nReverting config to default\\n\\n''', exc_info=True)\r\n            os.replace(filename, os.path.join(script_path, \"tmp\", \"config.json\"))\r\n            self.data = {}\r\n        # 1.6.0 VAE defaults\r\n        if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:\r\n            self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')\r\n\r\n        # 1.1.1 quicksettings list migration\r\n        if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:\r\n            self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]\r\n\r\n        # 1.4.0 ui_reorder\r\n        if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and \"ui_reorder_list\" not in self.data:\r\n            self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]\r\n\r\n        bad_settings = 0\r\n        for k, v in self.data.items():\r\n            info = self.data_labels.get(k, None)\r\n            if info is not None and not self.same_type(info.default, v):\r\n                print(f\"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})\", file=sys.stderr)\r\n                bad_settings += 1\r\n\r\n        if bad_settings > 0:\r\n            print(f\"The program is likely to not work with bad settings.\\nSettings file: {filename}\\nEither fix the file, or delete it and restart.\", file=sys.stderr)\r\n\r\n    def onchange(self, key, func, call=True):\r\n        item = self.data_labels.get(key)\r\n        item.onchange = func\r\n\r\n        if call:\r\n            func()\r\n\r\n    def dumpjson(self):\r\n        d = {k: self.data.get(k, v.default) for k, v in 
self.data_labels.items()}\r\n        d[\"_comments_before\"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}\r\n        d[\"_comments_after\"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}\r\n\r\n        item_categories = {}\r\n        for item in self.data_labels.values():\r\n            if item.section[0] is None:\r\n                continue\r\n\r\n            category = categories.mapping.get(item.category_id)\r\n            category = \"Uncategorized\" if category is None else category.label\r\n            if category not in item_categories:\r\n                item_categories[category] = item.section[1]\r\n\r\n        # _categories is a list of pairs: [section, category]. Each section (a setting page) will get a special heading above it with the category as text.\r\n        d[\"_categories\"] = [[v, k] for k, v in item_categories.items()] + [[\"Defaults\", \"Other\"]]\r\n\r\n        return json.dumps(d)\r\n\r\n    def add_option(self, key, info):\r\n        self.data_labels[key] = info\r\n        if key not in self.data and not info.do_not_save:\r\n            self.data[key] = info.default\r\n\r\n    def reorder(self):\r\n        \"\"\"Reorder settings so that:\r\n            - all items related to section always go together\r\n            - all sections belonging to a category go together\r\n            - sections inside a category are ordered alphabetically\r\n            - categories are ordered by creation order\r\n\r\n        Category is a superset of sections: for category \"postprocessing\" there could be multiple sections: \"face restoration\", \"upscaling\".\r\n\r\n        This function also changes items' category_id so that all items belonging to a section have the same category_id.\r\n        \"\"\"\r\n\r\n        category_ids = {}\r\n        section_categories = {}\r\n\r\n        settings_items = self.data_labels.items()\r\n        for _, item in 
settings_items:\r\n            if item.section not in section_categories:\r\n                section_categories[item.section] = item.category_id\r\n\r\n        for _, item in settings_items:\r\n            item.category_id = section_categories.get(item.section)\r\n\r\n        for category_id in categories.mapping:\r\n            if category_id not in category_ids:\r\n                category_ids[category_id] = len(category_ids)\r\n\r\n        def sort_key(x):\r\n            item: OptionInfo = x[1]\r\n            category_order = category_ids.get(item.category_id, len(category_ids))\r\n            section_order = item.section[1]\r\n\r\n            return category_order, section_order\r\n\r\n        self.data_labels = dict(sorted(settings_items, key=sort_key))\r\n\r\n    def cast_value(self, key, value):\r\n        \"\"\"casts an arbitrary to the same type as this setting's value with key\r\n        Example: cast_value(\"eta_noise_seed_delta\", \"12\") -> returns 12 (an int rather than str)\r\n        \"\"\"\r\n\r\n        if value is None:\r\n            return None\r\n\r\n        default_value = self.data_labels[key].default\r\n        if default_value is None:\r\n            default_value = getattr(self, key, None)\r\n        if default_value is None:\r\n            return None\r\n\r\n        expected_type = type(default_value)\r\n        if expected_type == bool and value == \"False\":\r\n            value = False\r\n        else:\r\n            value = expected_type(value)\r\n\r\n        return value\r\n\r\n\r\n@dataclass\r\nclass OptionsCategory:\r\n    id: str\r\n    label: str\r\n\r\nclass OptionsCategories:\r\n    def __init__(self):\r\n        self.mapping = {}\r\n\r\n    def register_category(self, category_id, label):\r\n        if category_id in self.mapping:\r\n            return category_id\r\n\r\n        self.mapping[category_id] = OptionsCategory(category_id, label)\r\n\r\n\r\ncategories = OptionsCategories()\r\n"
  },
  {
    "path": "modules/patches.py",
    "content": "from collections import defaultdict\r\n\r\n\r\ndef patch(key, obj, field, replacement):\r\n    \"\"\"Replaces a function in a module or a class.\r\n\r\n    Also stores the original function in this module, possible to be retrieved via original(key, obj, field).\r\n    If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.\r\n\r\n    Arguments:\r\n        key: identifying information for who is doing the replacement. You can use __name__.\r\n        obj: the module or the class\r\n        field: name of the function as a string\r\n        replacement: the new function\r\n\r\n    Returns:\r\n        the original function\r\n    \"\"\"\r\n\r\n    patch_key = (obj, field)\r\n    if patch_key in originals[key]:\r\n        raise RuntimeError(f\"patch for {field} is already applied\")\r\n\r\n    original_func = getattr(obj, field)\r\n    originals[key][patch_key] = original_func\r\n\r\n    setattr(obj, field, replacement)\r\n\r\n    return original_func\r\n\r\n\r\ndef undo(key, obj, field):\r\n    \"\"\"Undoes the peplacement by the patch().\r\n\r\n    If the function is not replaced, raises an exception.\r\n\r\n    Arguments:\r\n        key: identifying information for who is doing the replacement. You can use __name__.\r\n        obj: the module or the class\r\n        field: name of the function as a string\r\n\r\n    Returns:\r\n        Always None\r\n    \"\"\"\r\n\r\n    patch_key = (obj, field)\r\n\r\n    if patch_key not in originals[key]:\r\n        raise RuntimeError(f\"there is no patch for {field} to undo\")\r\n\r\n    original_func = originals[key].pop(patch_key)\r\n    setattr(obj, field, original_func)\r\n\r\n    return None\r\n\r\n\r\ndef original(key, obj, field):\r\n    \"\"\"Returns the original function for the patch created by the patch() function\"\"\"\r\n    patch_key = (obj, field)\r\n\r\n    return originals[key].get(patch_key, None)\r\n\r\n\r\noriginals = defaultdict(dict)\r\n"
  },
  {
    "path": "modules/paths.py",
    "content": "import os\r\nimport sys\r\nfrom modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd  # noqa: F401\r\n\r\nimport modules.safe  # noqa: F401\r\n\r\n\r\ndef mute_sdxl_imports():\r\n    \"\"\"create fake modules that SDXL wants to import but doesn't actually use for our purposes\"\"\"\r\n\r\n    class Dummy:\r\n        pass\r\n\r\n    module = Dummy()\r\n    module.LPIPS = None\r\n    sys.modules['taming.modules.losses.lpips'] = module\r\n\r\n    module = Dummy()\r\n    module.StableDataModuleFromConfig = None\r\n    sys.modules['sgm.data'] = module\r\n\r\n\r\n# data_path = cmd_opts_pre.data\r\nsys.path.insert(0, script_path)\r\n\r\n# search for directory of stable diffusion in following places\r\nsd_path = None\r\npossible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]\r\nfor possible_sd_path in possible_sd_paths:\r\n    if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):\r\n        sd_path = os.path.abspath(possible_sd_path)\r\n        break\r\n\r\nassert sd_path is not None, f\"Couldn't find Stable Diffusion in any of: {possible_sd_paths}\"\r\n\r\nmute_sdxl_imports()\r\n\r\npath_dirs = [\r\n    (sd_path, 'ldm', 'Stable Diffusion', []),\r\n    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', [\"sgm\"]),\r\n    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),\r\n    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', [\"atstart\"]),\r\n]\r\n\r\npaths = {}\r\n\r\nfor d, must_exist, what, options in path_dirs:\r\n    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))\r\n    if not os.path.exists(must_exist_path):\r\n        print(f\"Warning: {what} not found at path {must_exist_path}\", file=sys.stderr)\r\n    else:\r\n        d = os.path.abspath(d)\r\n        if \"atstart\" in options:\r\n    
        sys.path.insert(0, d)\r\n        elif \"sgm\" in options:\r\n            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we\r\n            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.\r\n\r\n            sys.path.insert(0, d)\r\n            import sgm  # noqa: F401\r\n            sys.path.pop(0)\r\n        else:\r\n            sys.path.append(d)\r\n        paths[what] = d\r\n"
  },
  {
    "path": "modules/paths_internal.py",
    "content": "\"\"\"this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py\"\"\"\r\n\r\nimport argparse\r\nimport os\r\nimport sys\r\nimport shlex\r\nfrom pathlib import Path\r\n\r\n\r\nnormalized_filepath = lambda filepath: str(Path(filepath).absolute())\r\n\r\ncommandline_args = os.environ.get('COMMANDLINE_ARGS', \"\")\r\nsys.argv += shlex.split(commandline_args)\r\n\r\ncwd = os.getcwd()\r\nmodules_path = os.path.dirname(os.path.realpath(__file__))\r\nscript_path = os.path.dirname(modules_path)\r\n\r\nsd_configs_path = os.path.join(script_path, \"configs\")\r\nsd_default_config = os.path.join(sd_configs_path, \"v1-inference.yaml\")\r\nsd_model_file = os.path.join(script_path, 'model.ckpt')\r\ndefault_sd_model_file = sd_model_file\r\n\r\n# Parse the --data-dir flag first so we can use it as a base for our other argument default values\r\nparser_pre = argparse.ArgumentParser(add_help=False)\r\nparser_pre.add_argument(\"--data-dir\", type=str, default=os.path.dirname(modules_path), help=\"base path where all user data is stored\", )\r\nparser_pre.add_argument(\"--models-dir\", type=str, default=None, help=\"base path where models are stored; overrides --data-dir\", )\r\ncmd_opts_pre = parser_pre.parse_known_args()[0]\r\n\r\ndata_path = cmd_opts_pre.data_dir\r\n\r\nmodels_path = cmd_opts_pre.models_dir if cmd_opts_pre.models_dir else os.path.join(data_path, \"models\")\r\nextensions_dir = os.path.join(data_path, \"extensions\")\r\nextensions_builtin_dir = os.path.join(script_path, \"extensions-builtin\")\r\nconfig_states_dir = os.path.join(script_path, \"config_states\")\r\ndefault_output_dir = os.path.join(data_path, \"outputs\")\r\n\r\nroboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')\r\n"
  },
  {
    "path": "modules/postprocessing.py",
    "content": "import os\r\n\r\nfrom PIL import Image\r\n\r\nfrom modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils\r\nfrom modules.shared import opts\r\n\r\n\r\ndef run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):\r\n    devices.torch_gc()\r\n\r\n    shared.state.begin(job=\"extras\")\r\n\r\n    outputs = []\r\n\r\n    def get_images(extras_mode, image, image_folder, input_dir):\r\n        if extras_mode == 1:\r\n            for img in image_folder:\r\n                if isinstance(img, Image.Image):\r\n                    image = images.fix_image(img)\r\n                    fn = ''\r\n                else:\r\n                    image = images.read(os.path.abspath(img.name))\r\n                    fn = os.path.splitext(img.orig_name)[0]\r\n                yield image, fn\r\n        elif extras_mode == 2:\r\n            assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'\r\n            assert input_dir, 'input directory not selected'\r\n\r\n            image_list = shared.listfiles(input_dir)\r\n            for filename in image_list:\r\n                yield filename, filename\r\n        else:\r\n            assert image, 'image not selected'\r\n            yield image, None\r\n\r\n    if extras_mode == 2 and output_dir != '':\r\n        outpath = output_dir\r\n    else:\r\n        outpath = opts.outdir_samples or opts.outdir_extras_samples\r\n\r\n    infotext = ''\r\n\r\n    data_to_process = list(get_images(extras_mode, image, image_folder, input_dir))\r\n    shared.state.job_count = len(data_to_process)\r\n\r\n    for image_placeholder, name in data_to_process:\r\n        image_data: Image.Image\r\n\r\n        shared.state.nextjob()\r\n        shared.state.textinfo = name\r\n        shared.state.skipped = False\r\n\r\n        if shared.state.interrupted or 
shared.state.stopping_generation:\r\n            break\r\n\r\n        if isinstance(image_placeholder, str):\r\n            try:\r\n                image_data = images.read(image_placeholder)\r\n            except Exception:\r\n                continue\r\n        else:\r\n            image_data = image_placeholder\r\n\r\n        image_data = image_data if image_data.mode in (\"RGBA\", \"RGB\") else image_data.convert(\"RGB\")\r\n\r\n        parameters, existing_pnginfo = images.read_info_from_image(image_data)\r\n        if parameters:\r\n            existing_pnginfo[\"parameters\"] = parameters\r\n\r\n        initial_pp = scripts_postprocessing.PostprocessedImage(image_data)\r\n\r\n        scripts.scripts_postproc.run(initial_pp, args)\r\n\r\n        if shared.state.skipped:\r\n            continue\r\n\r\n        used_suffixes = {}\r\n        for pp in [initial_pp, *initial_pp.extra_images]:\r\n            suffix = pp.get_suffix(used_suffixes)\r\n\r\n            if opts.use_original_name_batch and name is not None:\r\n                basename = os.path.splitext(os.path.basename(name))[0]\r\n                forced_filename = basename + suffix\r\n            else:\r\n                basename = ''\r\n                forced_filename = None\r\n\r\n            infotext = \", \".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None])\r\n\r\n            if opts.enable_pnginfo:\r\n                pp.image.info = existing_pnginfo\r\n                pp.image.info[\"postprocessing\"] = infotext\r\n\r\n            shared.state.assign_current_image(pp.image)\r\n\r\n            if save_output:\r\n                fullfn, _ = images.save_image(pp.image, path=outpath, basename=basename, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name=\"extras\", existing_info=existing_pnginfo, forced_filename=forced_filename, suffix=suffix)\r\n\r\n                if pp.caption:\r\n    
                caption_filename = os.path.splitext(fullfn)[0] + \".txt\"\r\n                    existing_caption = \"\"\r\n                    try:\r\n                        with open(caption_filename, encoding=\"utf8\") as file:\r\n                            existing_caption = file.read().strip()\r\n                    except FileNotFoundError:\r\n                        pass\r\n\r\n                    action = shared.opts.postprocessing_existing_caption_action\r\n                    if action == 'Prepend' and existing_caption:\r\n                        caption = f\"{existing_caption} {pp.caption}\"\r\n                    elif action == 'Append' and existing_caption:\r\n                        caption = f\"{pp.caption} {existing_caption}\"\r\n                    elif action == 'Keep' and existing_caption:\r\n                        caption = existing_caption\r\n                    else:\r\n                        caption = pp.caption\r\n\r\n                    caption = caption.strip()\r\n                    if caption:\r\n                        with open(caption_filename, \"w\", encoding=\"utf8\") as file:\r\n                            file.write(caption)\r\n\r\n            if extras_mode != 2 or show_extras_results:\r\n                outputs.append(pp.image)\r\n\r\n    devices.torch_gc()\r\n    shared.state.end()\r\n    return outputs, ui_common.plaintext_to_html(infotext), ''\r\n\r\n\r\ndef run_postprocessing_webui(id_task, *args, **kwargs):\r\n    return run_postprocessing(*args, **kwargs)\r\n\r\n\r\ndef run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True, max_side_length: int = 0):\r\n    \"\"\"old handler for API\"\"\"\r\n\r\n    args = 
scripts.scripts_postproc.create_args_for_run({\r\n        \"Upscale\": {\r\n            \"upscale_enabled\": True,\r\n            \"upscale_mode\": resize_mode,\r\n            \"upscale_by\": upscaling_resize,\r\n            \"max_side_length\": max_side_length,\r\n            \"upscale_to_width\": upscaling_resize_w,\r\n            \"upscale_to_height\": upscaling_resize_h,\r\n            \"upscale_crop\": upscaling_crop,\r\n            \"upscaler_1_name\": extras_upscaler_1,\r\n            \"upscaler_2_name\": extras_upscaler_2,\r\n            \"upscaler_2_visibility\": extras_upscaler_2_visibility,\r\n        },\r\n        \"GFPGAN\": {\r\n            \"enable\": True,\r\n            \"gfpgan_visibility\": gfpgan_visibility,\r\n        },\r\n        \"CodeFormer\": {\r\n            \"enable\": True,\r\n            \"codeformer_visibility\": codeformer_visibility,\r\n            \"codeformer_weight\": codeformer_weight,\r\n        },\r\n    })\r\n\r\n    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)\r\n"
  },
  {
    "path": "modules/processing.py",
    "content": "from __future__ import annotations\r\nimport json\r\nimport logging\r\nimport math\r\nimport os\r\nimport sys\r\nimport hashlib\r\nfrom dataclasses import dataclass, field\r\n\r\nimport torch\r\nimport numpy as np\r\nfrom PIL import Image, ImageOps\r\nimport random\r\nimport cv2\r\nfrom skimage import exposure\r\nfrom typing import Any\r\n\r\nimport modules.sd_hijack\r\nfrom modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling\r\nfrom modules.rng import slerp # noqa: F401\r\nfrom modules.sd_hijack import model_hijack\r\nfrom modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes\r\nfrom modules.shared import opts, cmd_opts, state\r\nimport modules.shared as shared\r\nimport modules.paths as paths\r\nimport modules.face_restoration\r\nimport modules.images as images\r\nimport modules.styles\r\nimport modules.sd_models as sd_models\r\nimport modules.sd_vae as sd_vae\r\nfrom ldm.data.util import AddMiDaS\r\nfrom ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion\r\n\r\nfrom einops import repeat, rearrange\r\nfrom blendmodes.blend import blendLayers, BlendType\r\n\r\n\r\n# some of those options should not be changed at all because they would break the model, so I removed them from options.\r\nopt_C = 4\r\nopt_f = 8\r\n\r\n\r\ndef setup_color_correction(image):\r\n    logging.info(\"Calibrating color correction.\")\r\n    correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)\r\n    return correction_target\r\n\r\n\r\ndef apply_color_correction(correction, original_image):\r\n    logging.info(\"Applying color correction.\")\r\n    image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(\r\n        cv2.cvtColor(\r\n            np.asarray(original_image),\r\n            cv2.COLOR_RGB2LAB\r\n        ),\r\n        correction,\r\n        
channel_axis=2\r\n    ), cv2.COLOR_LAB2RGB).astype(\"uint8\"))\r\n\r\n    image = blendLayers(image, original_image, BlendType.LUMINOSITY)\r\n\r\n    return image.convert('RGB')\r\n\r\n\r\ndef uncrop(image, dest_size, paste_loc):\r\n    x, y, w, h = paste_loc\r\n    base_image = Image.new('RGBA', dest_size)\r\n    image = images.resize_image(1, image, w, h)\r\n    base_image.paste(image, (x, y))\r\n    image = base_image\r\n\r\n    return image\r\n\r\n\r\ndef apply_overlay(image, paste_loc, overlay):\r\n    if overlay is None:\r\n        return image, image.copy()\r\n\r\n    if paste_loc is not None:\r\n        image = uncrop(image, (overlay.width, overlay.height), paste_loc)\r\n\r\n    original_denoised_image = image.copy()\r\n\r\n    image = image.convert('RGBA')\r\n    image.alpha_composite(overlay)\r\n    image = image.convert('RGB')\r\n\r\n    return image, original_denoised_image\r\n\r\ndef create_binary_mask(image, round=True):\r\n    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):\r\n        if round:\r\n            image = image.split()[-1].convert(\"L\").point(lambda x: 255 if x > 128 else 0)\r\n        else:\r\n            image = image.split()[-1].convert(\"L\")\r\n    else:\r\n        image = image.convert('L')\r\n    return image\r\n\r\ndef txt2img_image_conditioning(sd_model, x, width, height):\r\n    if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models\r\n\r\n        # The \"masked-image\" in this case will just be all 0.5 since the entire image is masked.\r\n        image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5\r\n        image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))\r\n\r\n        # Add the fake full 1s mask to the first dimension.\r\n        image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)\r\n        image_conditioning = 
image_conditioning.to(x.dtype)\r\n\r\n        return image_conditioning\r\n\r\n    elif sd_model.model.conditioning_key == \"crossattn-adm\": # UnCLIP models\r\n\r\n        return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device)\r\n\r\n    else:\r\n        if sd_model.is_sdxl_inpaint:\r\n            # The \"masked-image\" in this case will just be all 0.5 since the entire image is masked.\r\n            image_conditioning = torch.ones(x.shape[0], 3, height, width, device=x.device) * 0.5\r\n            image_conditioning = images_tensor_to_samples(image_conditioning,\r\n                                                            approximation_indexes.get(opts.sd_vae_encode_method))\r\n\r\n            # Add the fake full 1s mask to the first dimension.\r\n            image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)\r\n            image_conditioning = image_conditioning.to(x.dtype)\r\n\r\n            return image_conditioning\r\n\r\n        # Dummy zero conditioning if we're not using inpainting or unclip models.\r\n        # Still takes up a bit of memory, but no encoder call.\r\n        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.\r\n        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)\r\n\r\n\r\n@dataclass(repr=False)\r\nclass StableDiffusionProcessing:\r\n    sd_model: object = None\r\n    outpath_samples: str = None\r\n    outpath_grids: str = None\r\n    prompt: str = \"\"\r\n    prompt_for_display: str = None\r\n    negative_prompt: str = \"\"\r\n    styles: list[str] = None\r\n    seed: int = -1\r\n    subseed: int = -1\r\n    subseed_strength: float = 0\r\n    seed_resize_from_h: int = -1\r\n    seed_resize_from_w: int = -1\r\n    seed_enable_extras: bool = True\r\n    sampler_name: str = None\r\n    scheduler: str = None\r\n    batch_size: int = 1\r\n    n_iter: int = 1\r\n    
steps: int = 50\r\n    cfg_scale: float = 7.0\r\n    width: int = 512\r\n    height: int = 512\r\n    restore_faces: bool = None\r\n    tiling: bool = None\r\n    do_not_save_samples: bool = False\r\n    do_not_save_grid: bool = False\r\n    extra_generation_params: dict[str, Any] = None\r\n    overlay_images: list = None\r\n    eta: float = None\r\n    do_not_reload_embeddings: bool = False\r\n    denoising_strength: float = None\r\n    ddim_discretize: str = None\r\n    s_min_uncond: float = None\r\n    s_churn: float = None\r\n    s_tmax: float = None\r\n    s_tmin: float = None\r\n    s_noise: float = None\r\n    override_settings: dict[str, Any] = None\r\n    override_settings_restore_afterwards: bool = True\r\n    sampler_index: int = None\r\n    refiner_checkpoint: str = None\r\n    refiner_switch_at: float = None\r\n    token_merging_ratio = 0\r\n    token_merging_ratio_hr = 0\r\n    disable_extra_networks: bool = False\r\n    firstpass_image: Image = None\r\n\r\n    scripts_value: scripts.ScriptRunner = field(default=None, init=False)\r\n    script_args_value: list = field(default=None, init=False)\r\n    scripts_setup_complete: bool = field(default=False, init=False)\r\n\r\n    cached_uc = [None, None]\r\n    cached_c = [None, None]\r\n\r\n    comments: dict = None\r\n    sampler: sd_samplers_common.Sampler | None = field(default=None, init=False)\r\n    is_using_inpainting_conditioning: bool = field(default=False, init=False)\r\n    paste_to: tuple | None = field(default=None, init=False)\r\n\r\n    is_hr_pass: bool = field(default=False, init=False)\r\n\r\n    c: tuple = field(default=None, init=False)\r\n    uc: tuple = field(default=None, init=False)\r\n\r\n    rng: rng.ImageRNG | None = field(default=None, init=False)\r\n    step_multiplier: int = field(default=1, init=False)\r\n    color_corrections: list = field(default=None, init=False)\r\n\r\n    all_prompts: list = field(default=None, init=False)\r\n    all_negative_prompts: list = 
field(default=None, init=False)\r\n    all_seeds: list = field(default=None, init=False)\r\n    all_subseeds: list = field(default=None, init=False)\r\n    iteration: int = field(default=0, init=False)\r\n    main_prompt: str = field(default=None, init=False)\r\n    main_negative_prompt: str = field(default=None, init=False)\r\n\r\n    prompts: list = field(default=None, init=False)\r\n    negative_prompts: list = field(default=None, init=False)\r\n    seeds: list = field(default=None, init=False)\r\n    subseeds: list = field(default=None, init=False)\r\n    extra_network_data: dict = field(default=None, init=False)\r\n\r\n    user: str = field(default=None, init=False)\r\n\r\n    sd_model_name: str = field(default=None, init=False)\r\n    sd_model_hash: str = field(default=None, init=False)\r\n    sd_vae_name: str = field(default=None, init=False)\r\n    sd_vae_hash: str = field(default=None, init=False)\r\n\r\n    is_api: bool = field(default=False, init=False)\r\n\r\n    def __post_init__(self):\r\n        if self.sampler_index is not None:\r\n            print(\"sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name\", file=sys.stderr)\r\n\r\n        self.comments = {}\r\n\r\n        if self.styles is None:\r\n            self.styles = []\r\n\r\n        self.sampler_noise_scheduler_override = None\r\n\r\n        self.extra_generation_params = self.extra_generation_params or {}\r\n        self.override_settings = self.override_settings or {}\r\n        self.script_args = self.script_args or {}\r\n\r\n        self.refiner_checkpoint_info = None\r\n\r\n        if not self.seed_enable_extras:\r\n            self.subseed = -1\r\n            self.subseed_strength = 0\r\n            self.seed_resize_from_h = 0\r\n            self.seed_resize_from_w = 0\r\n\r\n        self.cached_uc = StableDiffusionProcessing.cached_uc\r\n        self.cached_c = StableDiffusionProcessing.cached_c\r\n\r\n    def fill_fields_from_opts(self):\r\n     
   self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond\r\n        self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn\r\n        self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin\r\n        self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf')\r\n        self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise\r\n\r\n    @property\r\n    def sd_model(self):\r\n        return shared.sd_model\r\n\r\n    @sd_model.setter\r\n    def sd_model(self, value):\r\n        pass\r\n\r\n    @property\r\n    def scripts(self):\r\n        return self.scripts_value\r\n\r\n    @scripts.setter\r\n    def scripts(self, value):\r\n        self.scripts_value = value\r\n\r\n        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:\r\n            self.setup_scripts()\r\n\r\n    @property\r\n    def script_args(self):\r\n        return self.script_args_value\r\n\r\n    @script_args.setter\r\n    def script_args(self, value):\r\n        self.script_args_value = value\r\n\r\n        if self.scripts_value and self.script_args_value and not self.scripts_setup_complete:\r\n            self.setup_scripts()\r\n\r\n    def setup_scripts(self):\r\n        self.scripts_setup_complete = True\r\n\r\n        self.scripts.setup_scrips(self, is_ui=not self.is_api)\r\n\r\n    def comment(self, text):\r\n        self.comments[text] = 1\r\n\r\n    def txt2img_image_conditioning(self, x, width=None, height=None):\r\n        self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}\r\n\r\n        return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)\r\n\r\n    def depth2img_image_conditioning(self, source_image):\r\n        # Use the AddMiDaS helper to Format our source image to suit the MiDaS model\r\n        transformer = 
AddMiDaS(model_type=\"dpt_hybrid\")\r\n        transformed = transformer({\"jpg\": rearrange(source_image[0], \"c h w -> h w c\")})\r\n        midas_in = torch.from_numpy(transformed[\"midas_in\"][None, ...]).to(device=shared.device)\r\n        midas_in = repeat(midas_in, \"1 ... -> n ...\", n=self.batch_size)\r\n\r\n        conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))\r\n        conditioning = torch.nn.functional.interpolate(\r\n            self.sd_model.depth_model(midas_in),\r\n            size=conditioning_image.shape[2:],\r\n            mode=\"bicubic\",\r\n            align_corners=False,\r\n        )\r\n\r\n        (depth_min, depth_max) = torch.aminmax(conditioning)\r\n        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.\r\n        return conditioning\r\n\r\n    def edit_image_conditioning(self, source_image):\r\n        conditioning_image = shared.sd_model.encode_first_stage(source_image).mode()\r\n\r\n        return conditioning_image\r\n\r\n    def unclip_image_conditioning(self, source_image):\r\n        c_adm = self.sd_model.embedder(source_image)\r\n        if self.sd_model.noise_augmentor is not None:\r\n            noise_level = 0 # TODO: Allow other noise levels?\r\n            c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0]))\r\n            c_adm = torch.cat((c_adm, noise_level_emb), 1)\r\n        return c_adm\r\n\r\n    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):\r\n        self.is_using_inpainting_conditioning = True\r\n\r\n        # Handle the different mask inputs\r\n        if image_mask is not None:\r\n            if torch.is_tensor(image_mask):\r\n                conditioning_mask = image_mask\r\n            else:\r\n                conditioning_mask = 
np.array(image_mask.convert(\"L\"))\r\n                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0\r\n                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])\r\n\r\n                if round_image_mask:\r\n                    # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0\r\n                    conditioning_mask = torch.round(conditioning_mask)\r\n\r\n        else:\r\n            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])\r\n\r\n        # Create another latent image, this time with a masked version of the original input.\r\n        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.\r\n        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)\r\n        conditioning_image = torch.lerp(\r\n            source_image,\r\n            source_image * (1.0 - conditioning_mask),\r\n            getattr(self, \"inpainting_mask_weight\", shared.opts.inpainting_mask_weight)\r\n        )\r\n\r\n        # Encode the new masked image using first stage of network.\r\n        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))\r\n\r\n        # Create the concatenated conditioning tensor to be fed to `c_concat`\r\n        conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])\r\n        conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)\r\n        image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)\r\n        image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)\r\n\r\n        return image_conditioning\r\n\r\n    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):\r\n        source_image = 
devices.cond_cast_float(source_image)\r\n\r\n        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely\r\n        # identify itself with a field common to all models. The conditioning_key is also hybrid.\r\n        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):\r\n            return self.depth2img_image_conditioning(source_image)\r\n\r\n        if self.sd_model.cond_stage_key == \"edit\":\r\n            return self.edit_image_conditioning(source_image)\r\n\r\n        if self.sampler.conditioning_key in {'hybrid', 'concat'}:\r\n            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)\r\n\r\n        if self.sampler.conditioning_key == \"crossattn-adm\":\r\n            return self.unclip_image_conditioning(source_image)\r\n\r\n        if self.sampler.model_wrap.inner_model.is_sdxl_inpaint:\r\n            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)\r\n\r\n        # Dummy zero conditioning if we're not using inpainting or depth model.\r\n        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)\r\n\r\n    def init(self, all_prompts, all_seeds, all_subseeds):\r\n        pass\r\n\r\n    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):\r\n        raise NotImplementedError()\r\n\r\n    def close(self):\r\n        self.sampler = None\r\n        self.c = None\r\n        self.uc = None\r\n        if not opts.persistent_cond_cache:\r\n            StableDiffusionProcessing.cached_c = [None, None]\r\n            StableDiffusionProcessing.cached_uc = [None, None]\r\n\r\n    def get_token_merging_ratio(self, for_hr=False):\r\n        if for_hr:\r\n            return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio\r\n\r\n        return self.token_merging_ratio or 
opts.token_merging_ratio\r\n\r\n    def setup_prompts(self):\r\n        if isinstance(self.prompt,list):\r\n            self.all_prompts = self.prompt\r\n        elif isinstance(self.negative_prompt, list):\r\n            self.all_prompts = [self.prompt] * len(self.negative_prompt)\r\n        else:\r\n            self.all_prompts = self.batch_size * self.n_iter * [self.prompt]\r\n\r\n        if isinstance(self.negative_prompt, list):\r\n            self.all_negative_prompts = self.negative_prompt\r\n        else:\r\n            self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts)\r\n\r\n        if len(self.all_prompts) != len(self.all_negative_prompts):\r\n            raise RuntimeError(f\"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})\")\r\n\r\n        self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]\r\n        self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]\r\n\r\n        self.main_prompt = self.all_prompts[0]\r\n        self.main_negative_prompt = self.all_negative_prompts[0]\r\n\r\n    def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):\r\n        \"\"\"Returns parameters that invalidate the cond cache if changed\"\"\"\r\n\r\n        return (\r\n            required_prompts,\r\n            steps,\r\n            hires_steps,\r\n            use_old_scheduling,\r\n            opts.CLIP_stop_at_last_layers,\r\n            shared.sd_model.sd_checkpoint_info,\r\n            extra_network_data,\r\n            opts.sdxl_crop_left,\r\n            opts.sdxl_crop_top,\r\n            self.width,\r\n            self.height,\r\n            opts.fp8_storage,\r\n            opts.cache_fp16_weight,\r\n            opts.emphasis,\r\n        )\r\n\r\n    def 
get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):\r\n        \"\"\"\r\n        Returns the result of calling function(shared.sd_model, required_prompts, steps)\r\n        using a cache to store the result if the same arguments have been used before.\r\n\r\n        cache is an array containing two elements. The first element is a tuple\r\n        representing the previously used arguments, or None if no arguments\r\n        have been used before. The second element is where the previously\r\n        computed result is stored.\r\n\r\n        caches is a list with items described above.\r\n        \"\"\"\r\n\r\n        if shared.opts.use_old_scheduling:\r\n            old_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, False)\r\n            new_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(required_prompts, steps, hires_steps, True)\r\n            if old_schedules != new_schedules:\r\n                self.extra_generation_params[\"Old prompt editing timelines\"] = True\r\n\r\n        cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)\r\n\r\n        for cache in caches:\r\n            if cache[0] is not None and cached_params == cache[0]:\r\n                return cache[1]\r\n\r\n        cache = caches[0]\r\n\r\n        with devices.autocast():\r\n            cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)\r\n\r\n        cache[0] = cached_params\r\n        return cache[1]\r\n\r\n    def setup_conds(self):\r\n        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)\r\n        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)\r\n\r\n        sampler_config = 
sd_samplers.find_sampler_config(self.sampler_name)\r\n        total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps\r\n        self.step_multiplier = total_steps // self.steps\r\n        self.firstpass_steps = total_steps\r\n\r\n        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)\r\n        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)\r\n\r\n    def get_conds(self):\r\n        return self.c, self.uc\r\n\r\n    def parse_extra_network_prompts(self):\r\n        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)\r\n\r\n    def save_samples(self) -> bool:\r\n        \"\"\"Returns whether generated images need to be written to disk\"\"\"\r\n        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)\r\n\r\n\r\nclass Processed:\r\n    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info=\"\", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=\"\"):\r\n        self.images = images_list\r\n        self.prompt = p.prompt\r\n        self.negative_prompt = p.negative_prompt\r\n        self.seed = seed\r\n        self.subseed = subseed\r\n        self.subseed_strength = p.subseed_strength\r\n        self.info = info\r\n        self.comments = \"\".join(f\"{comment}\\n\" for comment in p.comments)\r\n        self.width = p.width\r\n        self.height = p.height\r\n        self.sampler_name = p.sampler_name\r\n        self.cfg_scale = p.cfg_scale\r\n        self.image_cfg_scale = getattr(p, 'image_cfg_scale', None)\r\n        self.steps = p.steps\r\n        self.batch_size = p.batch_size\r\n        
self.restore_faces = p.restore_faces\r\n        self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None\r\n        self.sd_model_name = p.sd_model_name\r\n        self.sd_model_hash = p.sd_model_hash\r\n        self.sd_vae_name = p.sd_vae_name\r\n        self.sd_vae_hash = p.sd_vae_hash\r\n        self.seed_resize_from_w = p.seed_resize_from_w\r\n        self.seed_resize_from_h = p.seed_resize_from_h\r\n        self.denoising_strength = getattr(p, 'denoising_strength', None)\r\n        self.extra_generation_params = p.extra_generation_params\r\n        self.index_of_first_image = index_of_first_image\r\n        self.styles = p.styles\r\n        self.job_timestamp = state.job_timestamp\r\n        self.clip_skip = opts.CLIP_stop_at_last_layers\r\n        self.token_merging_ratio = p.token_merging_ratio\r\n        self.token_merging_ratio_hr = p.token_merging_ratio_hr\r\n\r\n        self.eta = p.eta\r\n        self.ddim_discretize = p.ddim_discretize\r\n        self.s_churn = p.s_churn\r\n        self.s_tmin = p.s_tmin\r\n        self.s_tmax = p.s_tmax\r\n        self.s_noise = p.s_noise\r\n        self.s_min_uncond = p.s_min_uncond\r\n        self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override\r\n        self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0]\r\n        self.negative_prompt = self.negative_prompt if not isinstance(self.negative_prompt, list) else self.negative_prompt[0]\r\n        self.seed = int(self.seed if not isinstance(self.seed, list) else self.seed[0]) if self.seed is not None else -1\r\n        self.subseed = int(self.subseed if not isinstance(self.subseed, list) else self.subseed[0]) if self.subseed is not None else -1\r\n        self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning\r\n\r\n        self.all_prompts = all_prompts or p.all_prompts or [self.prompt]\r\n        self.all_negative_prompts = all_negative_prompts or 
p.all_negative_prompts or [self.negative_prompt]\r\n        self.all_seeds = all_seeds or p.all_seeds or [self.seed]\r\n        self.all_subseeds = all_subseeds or p.all_subseeds or [self.subseed]\r\n        self.infotexts = infotexts or [info] * len(images_list)\r\n        self.version = program_version()\r\n\r\n    def js(self):\r\n        obj = {\r\n            \"prompt\": self.all_prompts[0],\r\n            \"all_prompts\": self.all_prompts,\r\n            \"negative_prompt\": self.all_negative_prompts[0],\r\n            \"all_negative_prompts\": self.all_negative_prompts,\r\n            \"seed\": self.seed,\r\n            \"all_seeds\": self.all_seeds,\r\n            \"subseed\": self.subseed,\r\n            \"all_subseeds\": self.all_subseeds,\r\n            \"subseed_strength\": self.subseed_strength,\r\n            \"width\": self.width,\r\n            \"height\": self.height,\r\n            \"sampler_name\": self.sampler_name,\r\n            \"cfg_scale\": self.cfg_scale,\r\n            \"steps\": self.steps,\r\n            \"batch_size\": self.batch_size,\r\n            \"restore_faces\": self.restore_faces,\r\n            \"face_restoration_model\": self.face_restoration_model,\r\n            \"sd_model_name\": self.sd_model_name,\r\n            \"sd_model_hash\": self.sd_model_hash,\r\n            \"sd_vae_name\": self.sd_vae_name,\r\n            \"sd_vae_hash\": self.sd_vae_hash,\r\n            \"seed_resize_from_w\": self.seed_resize_from_w,\r\n            \"seed_resize_from_h\": self.seed_resize_from_h,\r\n            \"denoising_strength\": self.denoising_strength,\r\n            \"extra_generation_params\": self.extra_generation_params,\r\n            \"index_of_first_image\": self.index_of_first_image,\r\n            \"infotexts\": self.infotexts,\r\n            \"styles\": self.styles,\r\n            \"job_timestamp\": self.job_timestamp,\r\n            \"clip_skip\": self.clip_skip,\r\n            \"is_using_inpainting_conditioning\": 
self.is_using_inpainting_conditioning,\r\n            \"version\": self.version,\r\n        }\r\n\r\n        return json.dumps(obj, default=lambda o: None)\r\n\r\n    def infotext(self, p: StableDiffusionProcessing, index):\r\n        return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)\r\n\r\n    def get_token_merging_ratio(self, for_hr=False):\r\n        return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio\r\n\r\n\r\ndef create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):\r\n    g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)\r\n    return g.next()\r\n\r\n\r\nclass DecodedSamples(list):\r\n    already_decoded = True\r\n\r\n\r\ndef decode_latent_batch(model, batch, target_device=None, check_for_nans=False):\r\n    samples = DecodedSamples()\r\n\r\n    if check_for_nans:\r\n        devices.test_for_nans(batch, \"unet\")\r\n\r\n    for i in range(batch.shape[0]):\r\n        sample = decode_first_stage(model, batch[i:i + 1])[0]\r\n\r\n        if check_for_nans:\r\n\r\n            try:\r\n                devices.test_for_nans(sample, \"vae\")\r\n            except devices.NansException as e:\r\n                if shared.opts.auto_vae_precision_bfloat16:\r\n                    autofix_dtype = torch.bfloat16\r\n                    autofix_dtype_text = \"bfloat16\"\r\n                    autofix_dtype_setting = \"Automatically convert VAE to bfloat16\"\r\n                    autofix_dtype_comment = \"\"\r\n                elif shared.opts.auto_vae_precision:\r\n                    autofix_dtype = torch.float32\r\n                    autofix_dtype_text = \"32-bit float\"\r\n                    autofix_dtype_setting = \"Automatically revert VAE to 
32-bit floats\"\r\n                    autofix_dtype_comment = \"\\nTo always start with 32-bit VAE, use --no-half-vae commandline flag.\"\r\n                else:\r\n                    raise e\r\n\r\n                if devices.dtype_vae == autofix_dtype:\r\n                    raise e\r\n\r\n                errors.print_error_explanation(\r\n                    \"A tensor with all NaNs was produced in VAE.\\n\"\r\n                    f\"Web UI will now convert VAE into {autofix_dtype_text} and retry.\\n\"\r\n                    f\"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}\"\r\n                )\r\n\r\n                devices.dtype_vae = autofix_dtype\r\n                model.first_stage_model.to(devices.dtype_vae)\r\n                batch = batch.to(devices.dtype_vae)\r\n\r\n                sample = decode_first_stage(model, batch[i:i + 1])[0]\r\n\r\n        if target_device is not None:\r\n            sample = sample.to(target_device)\r\n\r\n        samples.append(sample)\r\n\r\n    return samples\r\n\r\n\r\ndef get_fixed_seed(seed):\r\n    if seed == '' or seed is None:\r\n        seed = -1\r\n    elif isinstance(seed, str):\r\n        try:\r\n            seed = int(seed)\r\n        except Exception:\r\n            seed = -1\r\n\r\n    if seed == -1:\r\n        return int(random.randrange(4294967294))\r\n\r\n    return seed\r\n\r\n\r\ndef fix_seed(p):\r\n    p.seed = get_fixed_seed(p.seed)\r\n    p.subseed = get_fixed_seed(p.subseed)\r\n\r\n\r\ndef program_version():\r\n    import launch\r\n\r\n    res = launch.git_tag()\r\n    if res == \"<none>\":\r\n        res = None\r\n\r\n    return res\r\n\r\n\r\ndef create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):\r\n    \"\"\"\r\n    this function is used to generate the infotext that is stored in the generated images, it's contains the parameters 
that are required to generate the imagee\r\n    Args:\r\n        p: StableDiffusionProcessing\r\n        all_prompts: list[str]\r\n        all_seeds: list[int]\r\n        all_subseeds: list[int]\r\n        comments: list[str]\r\n        iteration: int\r\n        position_in_batch: int\r\n        use_main_prompt: bool\r\n        index: int\r\n        all_negative_prompts: list[str]\r\n\r\n    Returns: str\r\n\r\n    Extra generation params\r\n    p.extra_generation_params dictionary allows for additional parameters to be added to the infotext\r\n    this can be use by the base webui or extensions.\r\n    To add a new entry, add a new key value pair, the dictionary key will be used as the key of the parameter in the infotext\r\n    the value generation_params can be defined as:\r\n        - str | None\r\n        - List[str|None]\r\n        - callable func(**kwargs) -> str | None\r\n\r\n    When defined as a string, it will be used as without extra processing; this is this most common use case.\r\n\r\n    Defining as a list allows for parameter that changes across images in the job, for example, the 'Seed' parameter.\r\n    The list should have the same length as the total number of images in the entire job.\r\n\r\n    Defining as a callable function allows parameter cannot be generated earlier or when extra logic is required.\r\n    For example 'Hires prompt', due to reasons the hr_prompt might be changed by process in the pipeline or extensions\r\n    and may vary across different images, defining as a static string or list would not work.\r\n\r\n    The function takes locals() as **kwargs, as such will have access to variables like 'p' and 'index'.\r\n    the base signature of the function should be:\r\n        func(**kwargs) -> str | None\r\n    optionally it can have additional arguments that will be used in the function:\r\n        func(p, index, **kwargs) -> str | None\r\n    note: for better future compatibility even though this function will have access to 
all variables in the locals(),\r\n        it is recommended to only use the arguments present in the function signature of create_infotext.\r\n    For actual implementation examples, see StableDiffusionProcessingTxt2Img.init > get_hr_prompt.\r\n    \"\"\"\r\n\r\n    if use_main_prompt:\r\n        index = 0\r\n    elif index is None:\r\n        index = position_in_batch + iteration * p.batch_size\r\n\r\n    if all_negative_prompts is None:\r\n        all_negative_prompts = p.all_negative_prompts\r\n\r\n    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)\r\n    enable_hr = getattr(p, 'enable_hr', False)\r\n    token_merging_ratio = p.get_token_merging_ratio()\r\n    token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)\r\n\r\n    prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]\r\n    negative_prompt = p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]\r\n\r\n    uses_ensd = opts.eta_noise_seed_delta != 0\r\n    if uses_ensd:\r\n        uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)\r\n\r\n    generation_params = {\r\n        \"Steps\": p.steps,\r\n        \"Sampler\": p.sampler_name,\r\n        \"Schedule type\": p.scheduler,\r\n        \"CFG scale\": p.cfg_scale,\r\n        \"Image CFG scale\": getattr(p, 'image_cfg_scale', None),\r\n        \"Seed\": p.all_seeds[0] if use_main_prompt else all_seeds[index],\r\n        \"Face restoration\": opts.face_restoration_model if p.restore_faces else None,\r\n        \"Size\": f\"{p.width}x{p.height}\",\r\n        \"Model hash\": p.sd_model_hash if opts.add_model_hash_to_info else None,\r\n        \"Model\": p.sd_model_name if opts.add_model_name_to_info else None,\r\n        \"FP8 weight\": opts.fp8_storage if devices.fp8 else None,\r\n        \"Cache FP16 weight for LoRA\": opts.cache_fp16_weight if devices.fp8 else None,\r\n        \"VAE hash\": p.sd_vae_hash if opts.add_vae_hash_to_info else None,\r\n        \"VAE\": 
p.sd_vae_name if opts.add_vae_name_to_info else None,\r\n        \"Variation seed\": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),\r\n        \"Variation seed strength\": (None if p.subseed_strength == 0 else p.subseed_strength),\r\n        \"Seed resize from\": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f\"{p.seed_resize_from_w}x{p.seed_resize_from_h}\"),\r\n        \"Denoising strength\": p.extra_generation_params.get(\"Denoising strength\"),\r\n        \"Conditional mask weight\": getattr(p, \"inpainting_mask_weight\", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,\r\n        \"Clip skip\": None if clip_skip <= 1 else clip_skip,\r\n        \"ENSD\": opts.eta_noise_seed_delta if uses_ensd else None,\r\n        \"Token merging ratio\": None if token_merging_ratio == 0 else token_merging_ratio,\r\n        \"Token merging ratio hr\": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,\r\n        \"Init image hash\": getattr(p, 'init_img_hash', None),\r\n        \"RNG\": opts.randn_source if opts.randn_source != \"GPU\" else None,\r\n        \"Tiling\": \"True\" if p.tiling else None,\r\n        **p.extra_generation_params,\r\n        \"Version\": program_version() if opts.add_version_to_infotext else None,\r\n        \"User\": p.user if opts.add_user_name_to_info else None,\r\n    }\r\n\r\n    for key, value in generation_params.items():\r\n        try:\r\n            if isinstance(value, list):\r\n                generation_params[key] = value[index]\r\n            elif callable(value):\r\n                generation_params[key] = value(**locals())\r\n        except Exception:\r\n            errors.report(f'Error creating infotext for key \"{key}\"', exc_info=True)\r\n            generation_params[key] = None\r\n\r\n    generation_params_text = \", \".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v 
in generation_params.items() if v is not None])\r\n\r\n    negative_prompt_text = f\"\\nNegative prompt: {negative_prompt}\" if negative_prompt else \"\"\r\n\r\n    return f\"{prompt_text}{negative_prompt_text}\\n{generation_params_text}\".strip()\r\n\r\n\r\ndef process_images(p: StableDiffusionProcessing) -> Processed:\r\n    if p.scripts is not None:\r\n        p.scripts.before_process(p)\r\n\r\n    stored_opts = {k: opts.data[k] if k in opts.data else opts.get_default(k) for k in p.override_settings.keys() if k in opts.data}\r\n\r\n    try:\r\n        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint\r\n        # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards\r\n        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:\r\n            p.override_settings.pop('sd_model_checkpoint', None)\r\n            sd_models.reload_model_weights()\r\n\r\n        for k, v in p.override_settings.items():\r\n            opts.set(k, v, is_api=True, run_callbacks=False)\r\n\r\n            if k == 'sd_model_checkpoint':\r\n                sd_models.reload_model_weights()\r\n\r\n            if k == 'sd_vae':\r\n                sd_vae.reload_vae_weights()\r\n\r\n        sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())\r\n\r\n        # backwards compatibility, fix sampler and scheduler if invalid\r\n        sd_samplers.fix_p_invalid_sampler_and_scheduler(p)\r\n\r\n        with profiling.Profiler():\r\n            res = process_images_inner(p)\r\n\r\n    finally:\r\n        sd_models.apply_token_merging(p.sd_model, 0)\r\n\r\n        # restore opts to original state\r\n        if p.override_settings_restore_afterwards:\r\n            for k, v in stored_opts.items():\r\n                setattr(opts, k, v)\r\n\r\n                if k == 
'sd_vae':\r\n                    sd_vae.reload_vae_weights()\r\n\r\n    return res\r\n\r\n\r\ndef process_images_inner(p: StableDiffusionProcessing) -> Processed:\r\n    \"\"\"this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch\"\"\"\r\n\r\n    if isinstance(p.prompt, list):\r\n        assert(len(p.prompt) > 0)\r\n    else:\r\n        assert p.prompt is not None\r\n\r\n    devices.torch_gc()\r\n\r\n    seed = get_fixed_seed(p.seed)\r\n    subseed = get_fixed_seed(p.subseed)\r\n\r\n    if p.restore_faces is None:\r\n        p.restore_faces = opts.face_restoration\r\n\r\n    if p.tiling is None:\r\n        p.tiling = opts.tiling\r\n\r\n    if p.refiner_checkpoint not in (None, \"\", \"None\", \"none\"):\r\n        p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint)\r\n        if p.refiner_checkpoint_info is None:\r\n            raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}')\r\n\r\n    if hasattr(shared.sd_model, 'fix_dimensions'):\r\n        p.width, p.height = shared.sd_model.fix_dimensions(p.width, p.height)\r\n\r\n    p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra\r\n    p.sd_model_hash = shared.sd_model.sd_model_hash\r\n    p.sd_vae_name = sd_vae.get_loaded_vae_name()\r\n    p.sd_vae_hash = sd_vae.get_loaded_vae_hash()\r\n\r\n    modules.sd_hijack.model_hijack.apply_circular(p.tiling)\r\n    modules.sd_hijack.model_hijack.clear_comments()\r\n\r\n    p.fill_fields_from_opts()\r\n    p.setup_prompts()\r\n\r\n    if isinstance(seed, list):\r\n        p.all_seeds = seed\r\n    else:\r\n        p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]\r\n\r\n    if isinstance(subseed, list):\r\n        p.all_subseeds = subseed\r\n    else:\r\n        p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]\r\n\r\n    if 
os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:\r\n        model_hijack.embedding_db.load_textual_inversion_embeddings()\r\n\r\n    if p.scripts is not None:\r\n        p.scripts.process(p)\r\n\r\n    infotexts = []\r\n    output_images = []\r\n    with torch.no_grad(), p.sd_model.ema_scope():\r\n        with devices.autocast():\r\n            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)\r\n\r\n            # for OSX, loading the model during sampling changes the generated picture, so it is loaded here\r\n            if shared.opts.live_previews_enable and opts.show_progress_type == \"Approx NN\":\r\n                sd_vae_approx.model()\r\n\r\n            sd_unet.apply_unet()\r\n\r\n        if state.job_count == -1:\r\n            state.job_count = p.n_iter\r\n\r\n        for n in range(p.n_iter):\r\n            p.iteration = n\r\n\r\n            if state.skipped:\r\n                state.skipped = False\r\n\r\n            if state.interrupted or state.stopping_generation:\r\n                break\r\n\r\n            sd_models.reload_model_weights()  # model can be changed for example by refiner\r\n\r\n            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]\r\n            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]\r\n            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]\r\n            p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]\r\n\r\n            latent_channels = getattr(shared.sd_model, 'latent_channels', opt_C)\r\n            p.rng = rng.ImageRNG((latent_channels, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)\r\n\r\n            if p.scripts is not None:\r\n                p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)\r\n\r\n 
           if len(p.prompts) == 0:\r\n                break\r\n\r\n            p.parse_extra_network_prompts()\r\n\r\n            if not p.disable_extra_networks:\r\n                with devices.autocast():\r\n                    extra_networks.activate(p, p.extra_network_data)\r\n\r\n            if p.scripts is not None:\r\n                p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)\r\n\r\n            p.setup_conds()\r\n\r\n            p.extra_generation_params.update(model_hijack.extra_generation_params)\r\n\r\n            # params.txt should be saved after scripts.process_batch, since the\r\n            # infotext could be modified by that callback\r\n            # Example: a wildcard processed by process_batch sets an extra model\r\n            # strength, which is saved as \"Model Strength: 1.0\" in the infotext\r\n            if n == 0 and not cmd_opts.no_prompt_history:\r\n                with open(os.path.join(paths.data_path, \"params.txt\"), \"w\", encoding=\"utf8\") as file:\r\n                    processed = Processed(p, [])\r\n                    file.write(processed.infotext(p, 0))\r\n\r\n            for comment in model_hijack.comments:\r\n                p.comment(comment)\r\n\r\n            if p.n_iter > 1:\r\n                shared.state.job = f\"Batch {n+1} out of {p.n_iter}\"\r\n\r\n            sd_models.apply_alpha_schedule_override(p.sd_model, p)\r\n\r\n            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():\r\n                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)\r\n\r\n            if p.scripts is not None:\r\n                ps = scripts.PostSampleArgs(samples_ddim)\r\n                p.scripts.post_sample(p, ps)\r\n                samples_ddim = ps.samples\r\n\r\n            if getattr(samples_ddim, 'already_decoded', 
False):\r\n                x_samples_ddim = samples_ddim\r\n            else:\r\n                devices.test_for_nans(samples_ddim, \"unet\")\r\n\r\n                if opts.sd_vae_decode_method != 'Full':\r\n                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method\r\n                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)\r\n\r\n            x_samples_ddim = torch.stack(x_samples_ddim).float()\r\n            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)\r\n\r\n            del samples_ddim\r\n\r\n            if lowvram.is_enabled(shared.sd_model):\r\n                lowvram.send_everything_to_cpu()\r\n\r\n            devices.torch_gc()\r\n\r\n            state.nextjob()\r\n\r\n            if p.scripts is not None:\r\n                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)\r\n\r\n                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]\r\n                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]\r\n\r\n                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))\r\n                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)\r\n                x_samples_ddim = batch_params.images\r\n\r\n            def infotext(index=0, use_main_prompt=False):\r\n                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)\r\n\r\n            save_samples = p.save_samples()\r\n\r\n            for i, x_sample in enumerate(x_samples_ddim):\r\n                p.batch_index = i\r\n\r\n                x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2)\r\n                x_sample = x_sample.astype(np.uint8)\r\n\r\n                if p.restore_faces:\r\n                    if save_samples and opts.save_images_before_face_restoration:\r\n                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, \"\", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix=\"-before-face-restoration\")\r\n\r\n                    devices.torch_gc()\r\n\r\n                    x_sample = modules.face_restoration.restore_faces(x_sample)\r\n                    devices.torch_gc()\r\n\r\n                image = Image.fromarray(x_sample)\r\n\r\n                if p.scripts is not None:\r\n                    pp = scripts.PostprocessImageArgs(image)\r\n                    p.scripts.postprocess_image(p, pp)\r\n                    image = pp.image\r\n\r\n                mask_for_overlay = getattr(p, \"mask_for_overlay\", None)\r\n\r\n                if not shared.opts.overlay_inpaint:\r\n                    overlay_image = None\r\n                elif getattr(p, \"overlay_images\", None) is not None and i < len(p.overlay_images):\r\n                    overlay_image = p.overlay_images[i]\r\n                else:\r\n                    overlay_image = None\r\n\r\n                if p.scripts is not None:\r\n                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)\r\n                    p.scripts.postprocess_maskoverlay(p, ppmo)\r\n                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image\r\n\r\n                if p.color_corrections is not None and i < len(p.color_corrections):\r\n                    if save_samples and opts.save_images_before_color_correction:\r\n                        image_without_cc, _ = apply_overlay(image, p.paste_to, overlay_image)\r\n                        images.save_image(image_without_cc, p.outpath_samples, \"\", p.seeds[i], p.prompts[i], 
opts.samples_format, info=infotext(i), p=p, suffix=\"-before-color-correction\")\r\n                    image = apply_color_correction(p.color_corrections[i], image)\r\n\r\n                # If the intention is to show the output from the model\r\n                # that is being composited over the original image,\r\n                # we need to keep the original image around\r\n                # and use it in the composite step.\r\n                image, original_denoised_image = apply_overlay(image, p.paste_to, overlay_image)\r\n\r\n                if p.scripts is not None:\r\n                    pp = scripts.PostprocessImageArgs(image)\r\n                    p.scripts.postprocess_image_after_composite(p, pp)\r\n                    image = pp.image\r\n\r\n                if save_samples:\r\n                    images.save_image(image, p.outpath_samples, \"\", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)\r\n\r\n                text = infotext(i)\r\n                infotexts.append(text)\r\n                if opts.enable_pnginfo:\r\n                    image.info[\"parameters\"] = text\r\n                output_images.append(image)\r\n\r\n                if mask_for_overlay is not None:\r\n                    if opts.return_mask or opts.save_mask:\r\n                        image_mask = mask_for_overlay.convert('RGB')\r\n                        if save_samples and opts.save_mask:\r\n                            images.save_image(image_mask, p.outpath_samples, \"\", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix=\"-mask\")\r\n                        if opts.return_mask:\r\n                            output_images.append(image_mask)\r\n\r\n                    if opts.return_mask_composite or opts.save_mask_composite:\r\n                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, 
image.width, image.height).convert('L')).convert('RGBA')\r\n                        if save_samples and opts.save_mask_composite:\r\n                            images.save_image(image_mask_composite, p.outpath_samples, \"\", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix=\"-mask-composite\")\r\n                        if opts.return_mask_composite:\r\n                            output_images.append(image_mask_composite)\r\n\r\n            del x_samples_ddim\r\n\r\n            devices.torch_gc()\r\n\r\n        if not infotexts:\r\n            infotexts.append(Processed(p, []).infotext(p, 0))\r\n\r\n        p.color_corrections = None\r\n\r\n        index_of_first_image = 0\r\n        unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple\r\n        if (opts.return_grid or opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count:\r\n            grid = images.image_grid(output_images, p.batch_size)\r\n\r\n            if opts.return_grid:\r\n                text = infotext(use_main_prompt=True)\r\n                infotexts.insert(0, text)\r\n                if opts.enable_pnginfo:\r\n                    grid.info[\"parameters\"] = text\r\n                output_images.insert(0, grid)\r\n                index_of_first_image = 1\r\n            if opts.grid_save:\r\n                images.save_image(grid, p.outpath_grids, \"grid\", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)\r\n\r\n    if not p.disable_extra_networks and p.extra_network_data:\r\n        extra_networks.deactivate(p, p.extra_network_data)\r\n\r\n    devices.torch_gc()\r\n\r\n    res = Processed(\r\n        p,\r\n        images_list=output_images,\r\n        seed=p.all_seeds[0],\r\n        info=infotexts[0],\r\n        subseed=p.all_subseeds[0],\r\n        index_of_first_image=index_of_first_image,\r\n      
  infotexts=infotexts,\r\n    )\r\n\r\n    if p.scripts is not None:\r\n        p.scripts.postprocess(p, res)\r\n\r\n    return res\r\n\r\n\r\ndef old_hires_fix_first_pass_dimensions(width, height):\r\n    \"\"\"old algorithm for auto-calculating first pass size\"\"\"\r\n\r\n    desired_pixel_count = 512 * 512\r\n    actual_pixel_count = width * height\r\n    scale = math.sqrt(desired_pixel_count / actual_pixel_count)\r\n    width = math.ceil(scale * width / 64) * 64\r\n    height = math.ceil(scale * height / 64) * 64\r\n\r\n    return width, height\r\n\r\n\r\n@dataclass(repr=False)\r\nclass StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):\r\n    enable_hr: bool = False\r\n    denoising_strength: float = 0.75\r\n    firstphase_width: int = 0\r\n    firstphase_height: int = 0\r\n    hr_scale: float = 2.0\r\n    hr_upscaler: str = None\r\n    hr_second_pass_steps: int = 0\r\n    hr_resize_x: int = 0\r\n    hr_resize_y: int = 0\r\n    hr_checkpoint_name: str = None\r\n    hr_sampler_name: str = None\r\n    hr_scheduler: str = None\r\n    hr_prompt: str = ''\r\n    hr_negative_prompt: str = ''\r\n    force_task_id: str = None\r\n\r\n    cached_hr_uc = [None, None]\r\n    cached_hr_c = [None, None]\r\n\r\n    hr_checkpoint_info: dict = field(default=None, init=False)\r\n    hr_upscale_to_x: int = field(default=0, init=False)\r\n    hr_upscale_to_y: int = field(default=0, init=False)\r\n    truncate_x: int = field(default=0, init=False)\r\n    truncate_y: int = field(default=0, init=False)\r\n    applied_old_hires_behavior_to: tuple = field(default=None, init=False)\r\n    latent_scale_mode: dict = field(default=None, init=False)\r\n    hr_c: tuple | None = field(default=None, init=False)\r\n    hr_uc: tuple | None = field(default=None, init=False)\r\n    all_hr_prompts: list = field(default=None, init=False)\r\n    all_hr_negative_prompts: list = field(default=None, init=False)\r\n    hr_prompts: list = field(default=None, init=False)\r\n    
hr_negative_prompts: list = field(default=None, init=False)\r\n    hr_extra_network_data: list = field(default=None, init=False)\r\n\r\n    def __post_init__(self):\r\n        super().__post_init__()\r\n\r\n        if self.firstphase_width != 0 or self.firstphase_height != 0:\r\n            self.hr_upscale_to_x = self.width\r\n            self.hr_upscale_to_y = self.height\r\n            self.width = self.firstphase_width\r\n            self.height = self.firstphase_height\r\n\r\n        self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc\r\n        self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c\r\n\r\n    def calculate_target_resolution(self):\r\n        if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):\r\n            self.hr_resize_x = self.width\r\n            self.hr_resize_y = self.height\r\n            self.hr_upscale_to_x = self.width\r\n            self.hr_upscale_to_y = self.height\r\n\r\n            self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)\r\n            self.applied_old_hires_behavior_to = (self.width, self.height)\r\n\r\n        if self.hr_resize_x == 0 and self.hr_resize_y == 0:\r\n            self.extra_generation_params[\"Hires upscale\"] = self.hr_scale\r\n            self.hr_upscale_to_x = int(self.width * self.hr_scale)\r\n            self.hr_upscale_to_y = int(self.height * self.hr_scale)\r\n        else:\r\n            self.extra_generation_params[\"Hires resize\"] = f\"{self.hr_resize_x}x{self.hr_resize_y}\"\r\n\r\n            if self.hr_resize_y == 0:\r\n                self.hr_upscale_to_x = self.hr_resize_x\r\n                self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width\r\n            elif self.hr_resize_x == 0:\r\n                self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height\r\n                self.hr_upscale_to_y = self.hr_resize_y\r\n            else:\r\n        
        target_w = self.hr_resize_x\r\n                target_h = self.hr_resize_y\r\n                src_ratio = self.width / self.height\r\n                dst_ratio = self.hr_resize_x / self.hr_resize_y\r\n\r\n                if src_ratio < dst_ratio:\r\n                    self.hr_upscale_to_x = self.hr_resize_x\r\n                    self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width\r\n                else:\r\n                    self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height\r\n                    self.hr_upscale_to_y = self.hr_resize_y\r\n\r\n                self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f\r\n                self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f\r\n\r\n    def init(self, all_prompts, all_seeds, all_subseeds):\r\n        if self.enable_hr:\r\n            self.extra_generation_params[\"Denoising strength\"] = self.denoising_strength\r\n\r\n            if self.hr_checkpoint_name and self.hr_checkpoint_name != 'Use same checkpoint':\r\n                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)\r\n\r\n                if self.hr_checkpoint_info is None:\r\n                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')\r\n\r\n                self.extra_generation_params[\"Hires checkpoint\"] = self.hr_checkpoint_info.short_title\r\n\r\n            if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:\r\n                self.extra_generation_params[\"Hires sampler\"] = self.hr_sampler_name\r\n\r\n            def get_hr_prompt(p, index, prompt_text, **kwargs):\r\n                hr_prompt = p.all_hr_prompts[index]\r\n                return hr_prompt if hr_prompt != prompt_text else None\r\n\r\n            def get_hr_negative_prompt(p, index, negative_prompt, **kwargs):\r\n                hr_negative_prompt = p.all_hr_negative_prompts[index]\r\n                return 
hr_negative_prompt if hr_negative_prompt != negative_prompt else None\r\n\r\n            self.extra_generation_params[\"Hires prompt\"] = get_hr_prompt\r\n            self.extra_generation_params[\"Hires negative prompt\"] = get_hr_negative_prompt\r\n\r\n            self.extra_generation_params[\"Hires schedule type\"] = None  # to be set in sd_samplers_kdiffusion.py\r\n\r\n            if self.hr_scheduler is None:\r\n                self.hr_scheduler = self.scheduler\r\n\r\n            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, \"nearest\")\r\n            if self.enable_hr and self.latent_scale_mode is None:\r\n                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):\r\n                    raise Exception(f\"could not find upscaler named {self.hr_upscaler}\")\r\n\r\n            self.calculate_target_resolution()\r\n\r\n            if not state.processing_has_refined_job_count:\r\n                if state.job_count == -1:\r\n                    state.job_count = self.n_iter\r\n                if getattr(self, 'txt2img_upscale', False):\r\n                    total_steps = (self.hr_second_pass_steps or self.steps) * state.job_count\r\n                else:\r\n                    total_steps = (self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count\r\n                shared.total_tqdm.updateTotal(total_steps)\r\n                state.job_count = state.job_count * 2\r\n                state.processing_has_refined_job_count = True\r\n\r\n            if self.hr_second_pass_steps:\r\n                self.extra_generation_params[\"Hires steps\"] = self.hr_second_pass_steps\r\n\r\n            if self.hr_upscaler is not None:\r\n                self.extra_generation_params[\"Hires upscaler\"] = self.hr_upscaler\r\n\r\n    def sample(self, conditioning, unconditional_conditioning, seeds, 
subseeds, subseed_strength, prompts):\r\n        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)\r\n\r\n        if self.firstpass_image is not None and self.enable_hr:\r\n            # here we don't need to generate image, we just take self.firstpass_image and prepare it for hires fix\r\n\r\n            if self.latent_scale_mode is None:\r\n                image = np.array(self.firstpass_image).astype(np.float32) / 255.0 * 2.0 - 1.0\r\n                image = np.moveaxis(image, 2, 0)\r\n\r\n                samples = None\r\n                decoded_samples = torch.asarray(np.expand_dims(image, 0))\r\n\r\n            else:\r\n                image = np.array(self.firstpass_image).astype(np.float32) / 255.0\r\n                image = np.moveaxis(image, 2, 0)\r\n                image = torch.from_numpy(np.expand_dims(image, axis=0))\r\n                image = image.to(shared.device, dtype=devices.dtype_vae)\r\n\r\n                if opts.sd_vae_encode_method != 'Full':\r\n                    self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method\r\n\r\n                samples = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)\r\n                decoded_samples = None\r\n                devices.torch_gc()\r\n\r\n        else:\r\n            # here we generate an image normally\r\n\r\n            x = self.rng.next()\r\n            if self.scripts is not None:\r\n                self.scripts.process_before_every_sampling(\r\n                    p=self,\r\n                    x=x,\r\n                    noise=x,\r\n                    c=conditioning,\r\n                    uc=unconditional_conditioning\r\n                )\r\n\r\n            samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))\r\n            del x\r\n\r\n            if not self.enable_hr:\r\n                return samples\r\n\r\n 
           devices.torch_gc()\r\n\r\n            if self.latent_scale_mode is None:\r\n                decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)\r\n            else:\r\n                decoded_samples = None\r\n\r\n        with sd_models.SkipWritingToConfig():\r\n            sd_models.reload_model_weights(info=self.hr_checkpoint_info)\r\n\r\n        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)\r\n\r\n    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):\r\n        if shared.state.interrupted:\r\n            return samples\r\n\r\n        self.is_hr_pass = True\r\n        target_width = self.hr_upscale_to_x\r\n        target_height = self.hr_upscale_to_y\r\n\r\n        def save_intermediate(image, index):\r\n            \"\"\"saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images\"\"\"\r\n\r\n            if not self.save_samples() or not opts.save_images_before_highres_fix:\r\n                return\r\n\r\n            if not isinstance(image, Image.Image):\r\n                image = sd_samplers.sample_to_image(image, index, approximation=0)\r\n\r\n            info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)\r\n            images.save_image(image, self.outpath_samples, \"\", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix=\"-before-highres-fix\")\r\n\r\n        img2img_sampler_name = self.hr_sampler_name or self.sampler_name\r\n\r\n        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)\r\n\r\n        if self.latent_scale_mode is not None:\r\n            for i in range(samples.shape[0]):\r\n                save_intermediate(samples, i)\r\n\r\n     
       samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode[\"mode\"], antialias=self.latent_scale_mode[\"antialias\"])\r\n\r\n            # Avoid making the inpainting conditioning unless necessary as\r\n            # this does need some extra compute to decode / encode the image again.\r\n            if getattr(self, \"inpainting_mask_weight\", shared.opts.inpainting_mask_weight) < 1.0:\r\n                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)\r\n            else:\r\n                image_conditioning = self.txt2img_image_conditioning(samples)\r\n        else:\r\n            lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)\r\n\r\n            batch_images = []\r\n            for i, x_sample in enumerate(lowres_samples):\r\n                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)\r\n                x_sample = x_sample.astype(np.uint8)\r\n                image = Image.fromarray(x_sample)\r\n\r\n                save_intermediate(image, i)\r\n\r\n                image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler)\r\n                image = np.array(image).astype(np.float32) / 255.0\r\n                image = np.moveaxis(image, 2, 0)\r\n                batch_images.append(image)\r\n\r\n            decoded_samples = torch.from_numpy(np.array(batch_images))\r\n            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)\r\n\r\n            if opts.sd_vae_encode_method != 'Full':\r\n                self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method\r\n            samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))\r\n\r\n            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)\r\n\r\n        
shared.state.nextjob()\r\n\r\n        samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]\r\n\r\n        self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)\r\n        noise = self.rng.next()\r\n\r\n        # GC now before running the next img2img to prevent running out of memory\r\n        devices.torch_gc()\r\n\r\n        if not self.disable_extra_networks:\r\n            with devices.autocast():\r\n                extra_networks.activate(self, self.hr_extra_network_data)\r\n\r\n        with devices.autocast():\r\n            self.calculate_hr_conds()\r\n\r\n        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))\r\n\r\n        if self.scripts is not None:\r\n            self.scripts.before_hr(self)\r\n            self.scripts.process_before_every_sampling(\r\n                p=self,\r\n                x=samples,\r\n                noise=noise,\r\n                c=self.hr_c,\r\n                uc=self.hr_uc,\r\n            )\r\n\r\n        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)\r\n\r\n        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())\r\n\r\n        self.sampler = None\r\n        devices.torch_gc()\r\n\r\n        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)\r\n\r\n        self.is_hr_pass = False\r\n        return decoded_samples\r\n\r\n    def close(self):\r\n        super().close()\r\n        self.hr_c = None\r\n        self.hr_uc = None\r\n        if not opts.persistent_cond_cache:\r\n            StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, 
None]\r\n            StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]\r\n\r\n    def setup_prompts(self):\r\n        super().setup_prompts()\r\n\r\n        if not self.enable_hr:\r\n            return\r\n\r\n        if self.hr_prompt == '':\r\n            self.hr_prompt = self.prompt\r\n\r\n        if self.hr_negative_prompt == '':\r\n            self.hr_negative_prompt = self.negative_prompt\r\n\r\n        if isinstance(self.hr_prompt, list):\r\n            self.all_hr_prompts = self.hr_prompt\r\n        else:\r\n            self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]\r\n\r\n        if isinstance(self.hr_negative_prompt, list):\r\n            self.all_hr_negative_prompts = self.hr_negative_prompt\r\n        else:\r\n            self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]\r\n\r\n        self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]\r\n        self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]\r\n\r\n    def calculate_hr_conds(self):\r\n        if self.hr_c is not None:\r\n            return\r\n\r\n        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)\r\n        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)\r\n\r\n        sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name)\r\n        steps = self.hr_second_pass_steps or self.steps\r\n        total_steps = sampler_config.total_steps(steps) if sampler_config else steps\r\n\r\n        self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], 
self.hr_extra_network_data, total_steps)\r\n        self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)\r\n\r\n    def setup_conds(self):\r\n        if self.is_hr_pass:\r\n            # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model\r\n            self.hr_c = None\r\n            self.calculate_hr_conds()\r\n            return\r\n\r\n        super().setup_conds()\r\n\r\n        self.hr_uc = None\r\n        self.hr_c = None\r\n\r\n        if self.enable_hr and self.hr_checkpoint_info is None:\r\n            if shared.opts.hires_fix_use_firstpass_conds:\r\n                self.calculate_hr_conds()\r\n\r\n            elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded\r\n                with devices.autocast():\r\n                    extra_networks.activate(self, self.hr_extra_network_data)\r\n\r\n                self.calculate_hr_conds()\r\n\r\n                with devices.autocast():\r\n                    extra_networks.activate(self, self.extra_network_data)\r\n\r\n    def get_conds(self):\r\n        if self.is_hr_pass:\r\n            return self.hr_c, self.hr_uc\r\n\r\n        return super().get_conds()\r\n\r\n    def parse_extra_network_prompts(self):\r\n        res = super().parse_extra_network_prompts()\r\n\r\n        if self.enable_hr:\r\n            self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]\r\n            self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]\r\n\r\n            self.hr_prompts, self.hr_extra_network_data = 
extra_networks.parse_prompts(self.hr_prompts)\r\n\r\n        return res\r\n\r\n\r\n@dataclass(repr=False)\r\nclass StableDiffusionProcessingImg2Img(StableDiffusionProcessing):\r\n    init_images: list = None\r\n    resize_mode: int = 0\r\n    denoising_strength: float = 0.75\r\n    image_cfg_scale: float = None\r\n    mask: Any = None\r\n    mask_blur_x: int = 4\r\n    mask_blur_y: int = 4\r\n    mask_blur: int = None\r\n    mask_round: bool = True\r\n    inpainting_fill: int = 0\r\n    inpaint_full_res: bool = True\r\n    inpaint_full_res_padding: int = 0\r\n    inpainting_mask_invert: int = 0\r\n    initial_noise_multiplier: float = None\r\n    latent_mask: Image = None\r\n    force_task_id: str = None\r\n\r\n    image_mask: Any = field(default=None, init=False)\r\n\r\n    nmask: torch.Tensor = field(default=None, init=False)\r\n    image_conditioning: torch.Tensor = field(default=None, init=False)\r\n    init_img_hash: str = field(default=None, init=False)\r\n    mask_for_overlay: Image = field(default=None, init=False)\r\n    init_latent: torch.Tensor = field(default=None, init=False)\r\n\r\n    def __post_init__(self):\r\n        super().__post_init__()\r\n\r\n        self.image_mask = self.mask\r\n        self.mask = None\r\n        self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier\r\n\r\n    @property\r\n    def mask_blur(self):\r\n        if self.mask_blur_x == self.mask_blur_y:\r\n            return self.mask_blur_x\r\n        return None\r\n\r\n    @mask_blur.setter\r\n    def mask_blur(self, value):\r\n        if isinstance(value, int):\r\n            self.mask_blur_x = value\r\n            self.mask_blur_y = value\r\n\r\n    def init(self, all_prompts, all_seeds, all_subseeds):\r\n        self.extra_generation_params[\"Denoising strength\"] = self.denoising_strength\r\n\r\n        self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == 
\"edit\" else None\r\n\r\n        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)\r\n        crop_region = None\r\n\r\n        image_mask = self.image_mask\r\n\r\n        if image_mask is not None:\r\n            # image_mask is passed in as RGBA by Gradio to support alpha masks,\r\n            # but we still want to support binary masks.\r\n            image_mask = create_binary_mask(image_mask, round=self.mask_round)\r\n\r\n            if self.inpainting_mask_invert:\r\n                image_mask = ImageOps.invert(image_mask)\r\n                self.extra_generation_params[\"Mask mode\"] = \"Inpaint not masked\"\r\n\r\n            if self.mask_blur_x > 0:\r\n                np_mask = np.array(image_mask)\r\n                kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1\r\n                np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)\r\n                image_mask = Image.fromarray(np_mask)\r\n\r\n            if self.mask_blur_y > 0:\r\n                np_mask = np.array(image_mask)\r\n                kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1\r\n                np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)\r\n                image_mask = Image.fromarray(np_mask)\r\n\r\n            if self.mask_blur_x > 0 or self.mask_blur_y > 0:\r\n                self.extra_generation_params[\"Mask blur\"] = self.mask_blur\r\n\r\n            if self.inpaint_full_res:\r\n                self.mask_for_overlay = image_mask\r\n                mask = image_mask.convert('L')\r\n                crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding)\r\n                if crop_region:\r\n                    crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)\r\n                    x1, y1, x2, y2 = crop_region\r\n                    mask = mask.crop(crop_region)\r\n                    image_mask = 
images.resize_image(2, mask, self.width, self.height)\r\n                    self.paste_to = (x1, y1, x2-x1, y2-y1)\r\n                    self.extra_generation_params[\"Inpaint area\"] = \"Only masked\"\r\n                    self.extra_generation_params[\"Masked area padding\"] = self.inpaint_full_res_padding\r\n                else:\r\n                    crop_region = None\r\n                    image_mask = None\r\n                    self.mask_for_overlay = None\r\n                    self.inpaint_full_res = False\r\n                    massage = 'Unable to perform \"Inpaint Only mask\" because mask is blank, switch to img2img mode.'\r\n                    model_hijack.comments.append(massage)\r\n                    logging.info(massage)\r\n            else:\r\n                image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)\r\n                np_mask = np.array(image_mask)\r\n                np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)\r\n                self.mask_for_overlay = Image.fromarray(np_mask)\r\n\r\n            self.overlay_images = []\r\n\r\n        latent_mask = self.latent_mask if self.latent_mask is not None else image_mask\r\n\r\n        add_color_corrections = opts.img2img_color_correction and self.color_corrections is None\r\n        if add_color_corrections:\r\n            self.color_corrections = []\r\n        imgs = []\r\n        for img in self.init_images:\r\n\r\n            # Save init image\r\n            if opts.save_init_img:\r\n                self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()\r\n                images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info)\r\n\r\n            image = images.flatten(img, opts.img2img_background_color)\r\n\r\n            if crop_region is None and self.resize_mode != 3:\r\n                image = 
images.resize_image(self.resize_mode, image, self.width, self.height)\r\n\r\n            if image_mask is not None:\r\n                if self.mask_for_overlay.size != (image.width, image.height):\r\n                    self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height)\r\n                image_masked = Image.new('RGBa', (image.width, image.height))\r\n                image_masked.paste(image.convert(\"RGBA\").convert(\"RGBa\"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))\r\n\r\n                self.overlay_images.append(image_masked.convert('RGBA'))\r\n\r\n            # crop_region is not None if we are doing inpaint full res\r\n            if crop_region is not None:\r\n                image = image.crop(crop_region)\r\n                image = images.resize_image(2, image, self.width, self.height)\r\n\r\n            if image_mask is not None:\r\n                if self.inpainting_fill != 1:\r\n                    image = masking.fill(image, latent_mask)\r\n\r\n                    if self.inpainting_fill == 0:\r\n                        self.extra_generation_params[\"Masked content\"] = 'fill'\r\n\r\n            if add_color_corrections:\r\n                self.color_corrections.append(setup_color_correction(image))\r\n\r\n            image = np.array(image).astype(np.float32) / 255.0\r\n            image = np.moveaxis(image, 2, 0)\r\n\r\n            imgs.append(image)\r\n\r\n        if len(imgs) == 1:\r\n            batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)\r\n            if self.overlay_images is not None:\r\n                self.overlay_images = self.overlay_images * self.batch_size\r\n\r\n            if self.color_corrections is not None and len(self.color_corrections) == 1:\r\n                self.color_corrections = self.color_corrections * self.batch_size\r\n\r\n        elif len(imgs) <= self.batch_size:\r\n            self.batch_size = 
len(imgs)\r\n            batch_images = np.array(imgs)\r\n        else:\r\n            raise RuntimeError(f\"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less\")\r\n\r\n        image = torch.from_numpy(batch_images)\r\n        image = image.to(shared.device, dtype=devices.dtype_vae)\r\n\r\n        if opts.sd_vae_encode_method != 'Full':\r\n            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method\r\n\r\n        self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)\r\n        devices.torch_gc()\r\n\r\n        if self.resize_mode == 3:\r\n            self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode=\"bilinear\")\r\n\r\n        if image_mask is not None:\r\n            init_mask = latent_mask\r\n            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))\r\n            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255\r\n            latmask = latmask[0]\r\n            if self.mask_round:\r\n                latmask = np.around(latmask)\r\n            latmask = np.tile(latmask[None], (self.init_latent.shape[1], 1, 1))\r\n\r\n            self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(devices.dtype)\r\n            self.nmask = torch.asarray(latmask).to(shared.device).type(devices.dtype)\r\n\r\n            # this needs to be fixed to be done in sample() using actual seeds for batches\r\n            if self.inpainting_fill == 2:\r\n                self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask\r\n                self.extra_generation_params[\"Masked content\"] = 'latent noise'\r\n\r\n            elif self.inpainting_fill == 3:\r\n                self.init_latent = self.init_latent * 
self.mask\r\n                self.extra_generation_params[\"Masked content\"] = 'latent nothing'\r\n\r\n        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)\r\n\r\n    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):\r\n        x = self.rng.next()\r\n\r\n        if self.initial_noise_multiplier != 1.0:\r\n            self.extra_generation_params[\"Noise multiplier\"] = self.initial_noise_multiplier\r\n            x *= self.initial_noise_multiplier\r\n\r\n        if self.scripts is not None:\r\n            self.scripts.process_before_every_sampling(\r\n                p=self,\r\n                x=self.init_latent,\r\n                noise=x,\r\n                c=conditioning,\r\n                uc=unconditional_conditioning\r\n            )\r\n        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)\r\n\r\n        if self.mask is not None:\r\n            blended_samples = samples * self.nmask + self.init_latent * self.mask\r\n\r\n            if self.scripts is not None:\r\n                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)\r\n                self.scripts.on_mask_blend(self, mba)\r\n                blended_samples = mba.blended_latent\r\n\r\n            samples = blended_samples\r\n\r\n        del x\r\n        devices.torch_gc()\r\n\r\n        return samples\r\n\r\n    def get_token_merging_ratio(self, for_hr=False):\r\n        return self.token_merging_ratio or (\"token_merging_ratio\" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio\r\n"
  },
  {
    "path": "modules/processing_scripts/comments.py",
    "content": "from modules import scripts, shared, script_callbacks\r\nimport re\r\n\r\n\r\ndef strip_comments(text):\r\n    text = re.sub('(^|\\n)#[^\\n]*(\\n|$)', '\\n', text)  # while line comment\r\n    text = re.sub('#[^\\n]*(\\n|$)', '\\n', text)  # in the middle of the line comment\r\n\r\n    return text\r\n\r\n\r\nclass ScriptStripComments(scripts.Script):\r\n    def title(self):\r\n        return \"Comments\"\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def process(self, p, *args):\r\n        if not shared.opts.enable_prompt_comments:\r\n            return\r\n\r\n        p.all_prompts = [strip_comments(x) for x in p.all_prompts]\r\n        p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts]\r\n\r\n        p.main_prompt = strip_comments(p.main_prompt)\r\n        p.main_negative_prompt = strip_comments(p.main_negative_prompt)\r\n\r\n        if getattr(p, 'enable_hr', False):\r\n            p.all_hr_prompts = [strip_comments(x) for x in p.all_hr_prompts]\r\n            p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts]\r\n\r\n            p.hr_prompt = strip_comments(p.hr_prompt)\r\n            p.hr_negative_prompt = strip_comments(p.hr_negative_prompt)\r\n\r\n\r\ndef before_token_counter(params: script_callbacks.BeforeTokenCounterParams):\r\n    if not shared.opts.enable_prompt_comments:\r\n        return\r\n\r\n    params.prompt = strip_comments(params.prompt)\r\n\r\n\r\nscript_callbacks.on_before_token_counter(before_token_counter)\r\n\r\n\r\nshared.options_templates.update(shared.options_section(('sd', \"Stable Diffusion\", \"sd\"), {\r\n    \"enable_prompt_comments\": shared.OptionInfo(True, \"Enable comments\").info(\"Use # anywhere in the prompt to hide the text between # and the end of the line from the generation.\"),\r\n}))\r\n"
  },
  {
    "path": "modules/processing_scripts/refiner.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import scripts, sd_models\r\nfrom modules.infotext_utils import PasteField\r\nfrom modules.ui_common import create_refresh_button\r\nfrom modules.ui_components import InputAccordion\r\n\r\n\r\nclass ScriptRefiner(scripts.ScriptBuiltinUI):\r\n    section = \"accordions\"\r\n    create_group = False\r\n\r\n    def __init__(self):\r\n        pass\r\n\r\n    def title(self):\r\n        return \"Refiner\"\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def ui(self, is_img2img):\r\n        with InputAccordion(False, label=\"Refiner\", elem_id=self.elem_id(\"enable\")) as enable_refiner:\r\n            with gr.Row():\r\n                refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id(\"checkpoint\"), choices=sd_models.checkpoint_tiles(), value='', tooltip=\"switch to another model in the middle of generation\")\r\n                create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {\"choices\": sd_models.checkpoint_tiles()}, self.elem_id(\"checkpoint_refresh\"))\r\n\r\n                refiner_switch_at = gr.Slider(value=0.8, label=\"Switch at\", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id(\"switch_at\"), tooltip=\"fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation\")\r\n\r\n        def lookup_checkpoint(title):\r\n            info = sd_models.get_closet_checkpoint_match(title)\r\n            return None if info is None else info.title\r\n\r\n        self.infotext_fields = [\r\n            PasteField(enable_refiner, lambda d: 'Refiner' in d),\r\n            PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api=\"refiner_checkpoint\"),\r\n            PasteField(refiner_switch_at, 'Refiner switch at', api=\"refiner_switch_at\"),\r\n        ]\r\n\r\n        return enable_refiner, refiner_checkpoint, 
refiner_switch_at\r\n\r\n    def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at):\r\n        # the actual implementation is in sd_samplers_common.py, apply_refiner\r\n\r\n        if not enable_refiner or refiner_checkpoint in (None, \"\", \"None\"):\r\n            p.refiner_checkpoint = None\r\n            p.refiner_switch_at = None\r\n        else:\r\n            p.refiner_checkpoint = refiner_checkpoint\r\n            p.refiner_switch_at = refiner_switch_at\r\n"
  },
  {
    "path": "modules/processing_scripts/sampler.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import scripts, sd_samplers, sd_schedulers, shared\r\nfrom modules.infotext_utils import PasteField\r\nfrom modules.ui_components import FormRow, FormGroup\r\n\r\n\r\nclass ScriptSampler(scripts.ScriptBuiltinUI):\r\n    section = \"sampler\"\r\n\r\n    def __init__(self):\r\n        self.steps = None\r\n        self.sampler_name = None\r\n        self.scheduler = None\r\n\r\n    def title(self):\r\n        return \"Sampler\"\r\n\r\n    def ui(self, is_img2img):\r\n        sampler_names = [x.name for x in sd_samplers.visible_samplers()]\r\n        scheduler_names = [x.label for x in sd_schedulers.schedulers]\r\n\r\n        if shared.opts.samplers_in_dropdown:\r\n            with FormRow(elem_id=f\"sampler_selection_{self.tabname}\"):\r\n                self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f\"{self.tabname}_sampling\", choices=sampler_names, value=sampler_names[0])\r\n                self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f\"{self.tabname}_scheduler\", choices=scheduler_names, value=scheduler_names[0])\r\n                self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f\"{self.tabname}_steps\", label=\"Sampling steps\", value=20)\r\n        else:\r\n            with FormGroup(elem_id=f\"sampler_selection_{self.tabname}\"):\r\n                self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f\"{self.tabname}_steps\", label=\"Sampling steps\", value=20)\r\n                self.sampler_name = gr.Radio(label='Sampling method', elem_id=f\"{self.tabname}_sampling\", choices=sampler_names, value=sampler_names[0])\r\n                self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f\"{self.tabname}_scheduler\", choices=scheduler_names, value=scheduler_names[0])\r\n\r\n        self.infotext_fields = [\r\n            PasteField(self.steps, \"Steps\", api=\"steps\"),\r\n            PasteField(self.sampler_name, 
sd_samplers.get_sampler_from_infotext, api=\"sampler_name\"),\r\n            PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api=\"scheduler\"),\r\n        ]\r\n\r\n        return self.steps, self.sampler_name, self.scheduler\r\n\r\n    def setup(self, p, steps, sampler_name, scheduler):\r\n        p.steps = steps\r\n        p.sampler_name = sampler_name\r\n        p.scheduler = scheduler\r\n"
  },
  {
    "path": "modules/processing_scripts/seed.py",
    "content": "import json\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import scripts, ui, errors\r\nfrom modules.infotext_utils import PasteField\r\nfrom modules.shared import cmd_opts\r\nfrom modules.ui_components import ToolButton\r\nfrom modules import infotext_utils\r\n\r\n\r\nclass ScriptSeed(scripts.ScriptBuiltinUI):\r\n    section = \"seed\"\r\n    create_group = False\r\n\r\n    def __init__(self):\r\n        self.seed = None\r\n        self.reuse_seed = None\r\n        self.reuse_subseed = None\r\n\r\n    def title(self):\r\n        return \"Seed\"\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def ui(self, is_img2img):\r\n        with gr.Row(elem_id=self.elem_id(\"seed_row\")):\r\n            if cmd_opts.use_textbox_seed:\r\n                self.seed = gr.Textbox(label='Seed', value=\"\", elem_id=self.elem_id(\"seed\"), min_width=100)\r\n            else:\r\n                self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id(\"seed\"), min_width=100, precision=0)\r\n\r\n            random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id(\"random_seed\"), tooltip=\"Set seed to -1, which will cause a new random number to be used every time\")\r\n            reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id(\"reuse_seed\"), tooltip=\"Reuse seed from last generation, mostly useful if it was randomized\")\r\n\r\n            seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id(\"subseed_show\"), value=False)\r\n\r\n        with gr.Group(visible=False, elem_id=self.elem_id(\"seed_extras\")) as seed_extras:\r\n            with gr.Row(elem_id=self.elem_id(\"subseed_row\")):\r\n                subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id(\"subseed\"), precision=0)\r\n                random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id(\"random_subseed\"))\r\n                reuse_subseed = ToolButton(ui.reuse_symbol, 
elem_id=self.elem_id(\"reuse_subseed\"))\r\n                subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id(\"subseed_strength\"))\r\n\r\n            with gr.Row(elem_id=self.elem_id(\"seed_resize_from_row\")):\r\n                seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label=\"Resize seed from width\", value=0, elem_id=self.elem_id(\"seed_resize_from_w\"))\r\n                seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label=\"Resize seed from height\", value=0, elem_id=self.elem_id(\"seed_resize_from_h\"))\r\n\r\n        random_seed.click(fn=None, _js=\"function(){setRandomSeed('\" + self.elem_id(\"seed\") + \"')}\", show_progress=False, inputs=[], outputs=[])\r\n        random_subseed.click(fn=None, _js=\"function(){setRandomSeed('\" + self.elem_id(\"subseed\") + \"')}\", show_progress=False, inputs=[], outputs=[])\r\n\r\n        seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras])\r\n\r\n        self.infotext_fields = [\r\n            PasteField(self.seed, \"Seed\", api=\"seed\"),\r\n            PasteField(seed_checkbox, lambda d: \"Variation seed\" in d or \"Seed resize from-1\" in d),\r\n            PasteField(subseed, \"Variation seed\", api=\"subseed\"),\r\n            PasteField(subseed_strength, \"Variation seed strength\", api=\"subseed_strength\"),\r\n            PasteField(seed_resize_from_w, \"Seed resize from-1\", api=\"seed_resize_from_h\"),\r\n            PasteField(seed_resize_from_h, \"Seed resize from-2\", api=\"seed_resize_from_w\"),\r\n        ]\r\n\r\n        self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}')\r\n        self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}')\r\n\r\n        
return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h\r\n\r\n    def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h):\r\n        p.seed = seed\r\n\r\n        if seed_checkbox and subseed_strength > 0:\r\n            p.subseed = subseed\r\n            p.subseed_strength = subseed_strength\r\n\r\n        if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0:\r\n            p.seed_resize_from_w = seed_resize_from_w\r\n            p.seed_resize_from_h = seed_resize_from_h\r\n\r\n\r\ndef connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed):\r\n    \"\"\" Connects a 'reuse (sub)seed' button's click event so that it copies last used\r\n        (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength\r\n        was 0, i.e. no variation seed was used, it copies the normal seed value instead.\"\"\"\r\n\r\n    def copy_seed(gen_info_string: str, index):\r\n        res = -1\r\n        try:\r\n            gen_info = json.loads(gen_info_string)\r\n            infotext = gen_info.get('infotexts')[index]\r\n            gen_parameters = infotext_utils.parse_generation_parameters(infotext, [])\r\n            res = int(gen_parameters.get('Variation seed' if is_subseed else 'Seed', -1))\r\n        except Exception:\r\n            if gen_info_string:\r\n                errors.report(f\"Error retrieving seed from generation info: {gen_info_string}\", exc_info=True)\r\n\r\n        return [res, gr.update()]\r\n\r\n    reuse_seed.click(\r\n        fn=copy_seed,\r\n        _js=\"(x, y) => [x, selected_gallery_index()]\",\r\n        show_progress=False,\r\n        inputs=[generation_info, seed],\r\n        outputs=[seed, seed]\r\n    )\r\n"
  },
  {
    "path": "modules/profiling.py",
    "content": "import torch\r\n\r\nfrom modules import shared, ui_gradio_extensions\r\n\r\n\r\nclass Profiler:\r\n    def __init__(self):\r\n        if not shared.opts.profiling_enable:\r\n            self.profiler = None\r\n            return\r\n\r\n        activities = []\r\n        if \"CPU\" in shared.opts.profiling_activities:\r\n            activities.append(torch.profiler.ProfilerActivity.CPU)\r\n        if \"CUDA\" in shared.opts.profiling_activities:\r\n            activities.append(torch.profiler.ProfilerActivity.CUDA)\r\n\r\n        if not activities:\r\n            self.profiler = None\r\n            return\r\n\r\n        self.profiler = torch.profiler.profile(\r\n            activities=activities,\r\n            record_shapes=shared.opts.profiling_record_shapes,\r\n            profile_memory=shared.opts.profiling_profile_memory,\r\n            with_stack=shared.opts.profiling_with_stack\r\n        )\r\n\r\n    def __enter__(self):\r\n        if self.profiler:\r\n            self.profiler.__enter__()\r\n\r\n        return self\r\n\r\n    def __exit__(self, exc_type, exc, exc_tb):\r\n        if self.profiler:\r\n            shared.state.textinfo = \"Finishing profile...\"\r\n\r\n            self.profiler.__exit__(exc_type, exc, exc_tb)\r\n\r\n            self.profiler.export_chrome_trace(shared.opts.profiling_filename)\r\n\r\n\r\ndef webpath():\r\n    return ui_gradio_extensions.webpath(shared.opts.profiling_filename)\r\n\r\n"
  },
  {
    "path": "modules/progress.py",
    "content": "import base64\r\nimport io\r\nimport time\r\n\r\nimport gradio as gr\r\nfrom pydantic import BaseModel, Field\r\n\r\nfrom modules.shared import opts\r\n\r\nimport modules.shared as shared\r\nfrom collections import OrderedDict\r\nimport string\r\nimport random\r\nfrom typing import List\r\n\r\ncurrent_task = None\r\npending_tasks = OrderedDict()\r\nfinished_tasks = []\r\nrecorded_results = []\r\nrecorded_results_limit = 2\r\n\r\n\r\ndef start_task(id_task):\r\n    global current_task\r\n\r\n    current_task = id_task\r\n    pending_tasks.pop(id_task, None)\r\n\r\n\r\ndef finish_task(id_task):\r\n    global current_task\r\n\r\n    if current_task == id_task:\r\n        current_task = None\r\n\r\n    finished_tasks.append(id_task)\r\n    if len(finished_tasks) > 16:\r\n        finished_tasks.pop(0)\r\n\r\ndef create_task_id(task_type):\r\n    N = 7\r\n    res = ''.join(random.choices(string.ascii_uppercase +\r\n    string.digits, k=N))\r\n    return f\"task({task_type}-{res})\"\r\n\r\ndef record_results(id_task, res):\r\n    recorded_results.append((id_task, res))\r\n    if len(recorded_results) > recorded_results_limit:\r\n        recorded_results.pop(0)\r\n\r\n\r\ndef add_task_to_queue(id_job):\r\n    pending_tasks[id_job] = time.time()\r\n\r\nclass PendingTasksResponse(BaseModel):\r\n    size: int = Field(title=\"Pending task size\")\r\n    tasks: List[str] = Field(title=\"Pending task ids\")\r\n\r\nclass ProgressRequest(BaseModel):\r\n    id_task: str = Field(default=None, title=\"Task ID\", description=\"id of the task to get progress for\")\r\n    id_live_preview: int = Field(default=-1, title=\"Live preview image ID\", description=\"id of last received last preview image\")\r\n    live_preview: bool = Field(default=True, title=\"Include live preview\", description=\"boolean flag indicating whether to include the live preview image\")\r\n\r\n\r\nclass ProgressResponse(BaseModel):\r\n    active: bool = Field(title=\"Whether the task is being 
worked on right now\")\r\n    queued: bool = Field(title=\"Whether the task is in queue\")\r\n    completed: bool = Field(title=\"Whether the task has already finished\")\r\n    progress: float = Field(default=None, title=\"Progress\", description=\"The progress with a range of 0 to 1\")\r\n    eta: float = Field(default=None, title=\"ETA in secs\")\r\n    live_preview: str = Field(default=None, title=\"Live preview image\", description=\"Current live preview; a data: uri\")\r\n    id_live_preview: int = Field(default=None, title=\"Live preview image ID\", description=\"Send this together with next request to prevent receiving same image\")\r\n    textinfo: str = Field(default=None, title=\"Info text\", description=\"Info text used by WebUI.\")\r\n\r\n\r\ndef setup_progress_api(app):\r\n    app.add_api_route(\"/internal/pending-tasks\", get_pending_tasks, methods=[\"GET\"])\r\n    return app.add_api_route(\"/internal/progress\", progressapi, methods=[\"POST\"], response_model=ProgressResponse)\r\n\r\n\r\ndef get_pending_tasks():\r\n    pending_tasks_ids = list(pending_tasks)\r\n    pending_len = len(pending_tasks_ids)\r\n    return PendingTasksResponse(size=pending_len, tasks=pending_tasks_ids)\r\n\r\n\r\ndef progressapi(req: ProgressRequest):\r\n    active = req.id_task == current_task\r\n    queued = req.id_task in pending_tasks\r\n    completed = req.id_task in finished_tasks\r\n\r\n    if not active:\r\n        textinfo = \"Waiting...\"\r\n        if queued:\r\n            sorted_queued = sorted(pending_tasks.keys(), key=lambda x: pending_tasks[x])\r\n            queue_index = sorted_queued.index(req.id_task)\r\n            textinfo = \"In queue: {}/{}\".format(queue_index + 1, len(sorted_queued))\r\n        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo)\r\n\r\n    progress = 0\r\n\r\n    job_count, job_no = shared.state.job_count, shared.state.job_no\r\n    sampling_steps, sampling_step = 
shared.state.sampling_steps, shared.state.sampling_step\r\n\r\n    if job_count > 0:\r\n        progress += job_no / job_count\r\n    if sampling_steps > 0 and job_count > 0:\r\n        progress += 1 / job_count * sampling_step / sampling_steps\r\n\r\n    progress = min(progress, 1)\r\n\r\n    elapsed_since_start = time.time() - shared.state.time_start\r\n    predicted_duration = elapsed_since_start / progress if progress > 0 else None\r\n    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None\r\n\r\n    live_preview = None\r\n    id_live_preview = req.id_live_preview\r\n\r\n    if opts.live_previews_enable and req.live_preview:\r\n        shared.state.set_current_image()\r\n        if shared.state.id_live_preview != req.id_live_preview:\r\n            image = shared.state.current_image\r\n            if image is not None:\r\n                buffered = io.BytesIO()\r\n\r\n                if opts.live_previews_image_format == \"png\":\r\n                    # using optimize for large images takes an enormous amount of time\r\n                    if max(*image.size) <= 256:\r\n                        save_kwargs = {\"optimize\": True}\r\n                    else:\r\n                        save_kwargs = {\"optimize\": False, \"compress_level\": 1}\r\n\r\n                else:\r\n                    save_kwargs = {}\r\n\r\n                image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)\r\n                base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')\r\n                live_preview = f\"data:image/{opts.live_previews_image_format};base64,{base64_image}\"\r\n                id_live_preview = shared.state.id_live_preview\r\n\r\n    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)\r\n\r\n\r\ndef restore_progress(id_task):\r\n    while 
id_task == current_task or id_task in pending_tasks:\r\n        time.sleep(0.1)\r\n\r\n    res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)\r\n    if res is not None:\r\n        return res\r\n\r\n    return gr.update(), gr.update(), gr.update(), f\"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained\"\r\n"
  },
  {
    "path": "modules/prompt_parser.py",
    "content": "from __future__ import annotations\r\n\r\nimport re\r\nfrom collections import namedtuple\r\nimport lark\r\n\r\n# a prompt like this: \"fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][: in background:0.25] [shoddy:masterful:0.5]\"\r\n# will be represented with prompt_schedule like this (assuming steps=100):\r\n# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']\r\n# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']\r\n# [60, 'fantasy landscape with a lake and an oak in foreground in background masterful']\r\n# [75, 'fantasy landscape with a lake and an oak in background masterful']\r\n# [100, 'fantasy landscape with a lake and a christmas tree in background masterful']\r\n\r\nschedule_parser = lark.Lark(r\"\"\"\r\n!start: (prompt | /[][():]/+)*\r\nprompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*\r\n!emphasized: \"(\" prompt \")\"\r\n        | \"(\" prompt \":\" prompt \")\"\r\n        | \"[\" prompt \"]\"\r\nscheduled: \"[\" [prompt \":\"] prompt \":\" [WHITESPACE] NUMBER [WHITESPACE] \"]\"\r\nalternate: \"[\" prompt (\"|\" [prompt])+ \"]\"\r\nWHITESPACE: /\\s+/\r\nplain: /([^\\\\\\[\\]():|]|\\\\.)+/\r\n%import common.SIGNED_NUMBER -> NUMBER\r\n\"\"\")\r\n\r\ndef get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):\r\n    \"\"\"\r\n    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]\r\n    >>> g(\"test\")\r\n    [[10, 'test']]\r\n    >>> g(\"a [b:3]\")\r\n    [[3, 'a '], [10, 'a b']]\r\n    >>> g(\"a [b: 3]\")\r\n    [[3, 'a '], [10, 'a b']]\r\n    >>> g(\"a [[[b]]:2]\")\r\n    [[2, 'a '], [10, 'a [[b]]']]\r\n    >>> g(\"[(a:2):3]\")\r\n    [[3, ''], [10, '(a:2)']]\r\n    >>> g(\"a [b : c : 1] d\")\r\n    [[1, 'a b  d'], [10, 'a  c  d']]\r\n    >>> g(\"a[b:[c:d:2]:1]e\")\r\n    [[1, 'abe'], [2, 'ace'], [10, 'ade']]\r\n    >>> g(\"a 
[unbalanced\")\r\n    [[10, 'a [unbalanced']]\r\n    >>> g(\"a [b:.5] c\")\r\n    [[5, 'a  c'], [10, 'a b c']]\r\n    >>> g(\"a [{b|d{:.5] c\")  # not handling this right now\r\n    [[5, 'a  c'], [10, 'a {b|d{ c']]\r\n    >>> g(\"((a][:b:c [d:3]\")\r\n    [[3, '((a][:b:c '], [10, '((a][:b:c d']]\r\n    >>> g(\"[a|(b:1.1)]\")\r\n    [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]\r\n    >>> g(\"[fe|]male\")\r\n    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]\r\n    >>> g(\"[fe|||]male\")\r\n    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]\r\n    >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]\r\n    >>> g(\"a [b:.5] c\")\r\n    [[10, 'a b c']]\r\n    >>> g(\"a [b:1.5] c\")\r\n    [[5, 'a  c'], [10, 'a b c']]\r\n    \"\"\"\r\n\r\n    if hires_steps is None or use_old_scheduling:\r\n        int_offset = 0\r\n        flt_offset = 0\r\n        steps = base_steps\r\n    else:\r\n        int_offset = base_steps\r\n        flt_offset = 1.0\r\n        steps = hires_steps\r\n\r\n    def collect_steps(steps, tree):\r\n        res = [steps]\r\n\r\n        class CollectSteps(lark.Visitor):\r\n            def scheduled(self, tree):\r\n                s = tree.children[-2]\r\n                v = float(s)\r\n                if use_old_scheduling:\r\n                    v = v*steps if v<1 else v\r\n                else:\r\n                    if \".\" in s:\r\n                        v = (v - flt_offset) * steps\r\n                    else:\r\n                        v = (v - int_offset)\r\n                tree.children[-2] = min(steps, int(v))\r\n                if tree.children[-2] >= 1:\r\n                    res.append(tree.children[-2])\r\n\r\n            def alternate(self, 
tree):\r\n                res.extend(range(1, steps+1))\r\n\r\n        CollectSteps().visit(tree)\r\n        return sorted(set(res))\r\n\r\n    def at_step(step, tree):\r\n        class AtStep(lark.Transformer):\r\n            def scheduled(self, args):\r\n                before, after, _, when, _ = args\r\n                yield before or () if step <= when else after\r\n            def alternate(self, args):\r\n                args = [\"\" if not arg else arg for arg in args]\r\n                yield args[(step - 1) % len(args)]\r\n            def start(self, args):\r\n                def flatten(x):\r\n                    if isinstance(x, str):\r\n                        yield x\r\n                    else:\r\n                        for gen in x:\r\n                            yield from flatten(gen)\r\n                return ''.join(flatten(args))\r\n            def plain(self, args):\r\n                yield args[0].value\r\n            def __default__(self, data, children, meta):\r\n                for child in children:\r\n                    yield child\r\n        return AtStep().transform(tree)\r\n\r\n    def get_schedule(prompt):\r\n        try:\r\n            tree = schedule_parser.parse(prompt)\r\n        except lark.exceptions.LarkError:\r\n            if 0:\r\n                import traceback\r\n                traceback.print_exc()\r\n            return [[steps, prompt]]\r\n        return [[t, at_step(t, tree)] for t in collect_steps(steps, tree)]\r\n\r\n    promptdict = {prompt: get_schedule(prompt) for prompt in set(prompts)}\r\n    return [promptdict[prompt] for prompt in prompts]\r\n\r\n\r\nScheduledPromptConditioning = namedtuple(\"ScheduledPromptConditioning\", [\"end_at_step\", \"cond\"])\r\n\r\n\r\nclass SdConditioning(list):\r\n    \"\"\"\r\n    A list with prompts for stable diffusion's conditioner model.\r\n    Can also specify width and height of created image - SDXL needs it.\r\n    \"\"\"\r\n    def __init__(self, prompts, 
is_negative_prompt=False, width=None, height=None, copy_from=None):\r\n        super().__init__()\r\n        self.extend(prompts)\r\n\r\n        if copy_from is None:\r\n            copy_from = prompts\r\n\r\n        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)\r\n        self.width = width or getattr(copy_from, 'width', None)\r\n        self.height = height or getattr(copy_from, 'height', None)\r\n\r\n\r\n\r\ndef get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):\r\n    \"\"\"converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),\r\n    and the sampling step at which this condition is to be replaced by the next one.\r\n\r\n    Input:\r\n    (model, ['a red crown', 'a [blue:green:5] jeweled crown'], 20)\r\n\r\n    Output:\r\n    [\r\n        [\r\n            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0523,  ..., -0.4901, -0.3066,  0.0674], ..., [ 0.3317, -0.5102, -0.4066,  ...,  0.4119, -0.7647, -1.0160]], device='cuda:0'))\r\n        ],\r\n        [\r\n            ScheduledPromptConditioning(end_at_step=5, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.0192,  0.3867, -0.4644,  ...,  0.1135, -0.3696, -0.4625]], device='cuda:0')),\r\n            ScheduledPromptConditioning(end_at_step=20, cond=tensor([[-0.3886,  0.0229, -0.0522,  ..., -0.4901, -0.3067,  0.0673], ..., [-0.7352, -0.4356, -0.7888,  ...,  0.6994, -0.4312, -1.2593]], device='cuda:0'))\r\n        ]\r\n    ]\r\n    \"\"\"\r\n    res = []\r\n\r\n    prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)\r\n    cache = {}\r\n\r\n    for prompt, prompt_schedule in zip(prompts, prompt_schedules):\r\n\r\n        cached = cache.get(prompt, None)\r\n        if cached is not 
None:\r\n            res.append(cached)\r\n            continue\r\n\r\n        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)\r\n        conds = model.get_learned_conditioning(texts)\r\n\r\n        cond_schedule = []\r\n        for i, (end_at_step, _) in enumerate(prompt_schedule):\r\n            if isinstance(conds, dict):\r\n                cond = {k: v[i] for k, v in conds.items()}\r\n            else:\r\n                cond = conds[i]\r\n\r\n            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))\r\n\r\n        cache[prompt] = cond_schedule\r\n        res.append(cond_schedule)\r\n\r\n    return res\r\n\r\n\r\nre_AND = re.compile(r\"\\bAND\\b\")\r\nre_weight = re.compile(r\"^((?:\\s|.)*?)(?:\\s*:\\s*([-+]?(?:\\d+\\.?|\\d*\\.\\d+)))?\\s*$\")\r\n\r\n\r\ndef get_multicond_prompt_list(prompts: SdConditioning | list[str]):\r\n    res_indexes = []\r\n\r\n    prompt_indexes = {}\r\n    prompt_flat_list = SdConditioning(prompts)\r\n    prompt_flat_list.clear()\r\n\r\n    for prompt in prompts:\r\n        subprompts = re_AND.split(prompt)\r\n\r\n        indexes = []\r\n        for subprompt in subprompts:\r\n            match = re_weight.search(subprompt)\r\n\r\n            text, weight = match.groups() if match is not None else (subprompt, 1.0)\r\n\r\n            weight = float(weight) if weight is not None else 1.0\r\n\r\n            index = prompt_indexes.get(text, None)\r\n            if index is None:\r\n                index = len(prompt_flat_list)\r\n                prompt_flat_list.append(text)\r\n                prompt_indexes[text] = index\r\n\r\n            indexes.append((index, weight))\r\n\r\n        res_indexes.append(indexes)\r\n\r\n    return res_indexes, prompt_flat_list, prompt_indexes\r\n\r\n\r\nclass ComposableScheduledPromptConditioning:\r\n    def __init__(self, schedules, weight=1.0):\r\n        self.schedules: list[ScheduledPromptConditioning] = schedules\r\n        self.weight: float = 
weight\r\n\r\n\r\nclass MulticondLearnedConditioning:\r\n    def __init__(self, shape, batch):\r\n        self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS\r\n        self.batch: list[list[ComposableScheduledPromptConditioning]] = batch\r\n\r\n\r\ndef get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:\r\n    \"\"\"same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.\r\n    For each prompt, the list is obtained by splitting the prompt using the AND separator.\r\n\r\n    https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/\r\n    \"\"\"\r\n\r\n    res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)\r\n\r\n    learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)\r\n\r\n    res = []\r\n    for indexes in res_indexes:\r\n        res.append([ComposableScheduledPromptConditioning(learned_conditioning[i], weight) for i, weight in indexes])\r\n\r\n    return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)\r\n\r\n\r\nclass DictWithShape(dict):\r\n    def __init__(self, x, shape=None):\r\n        super().__init__()\r\n        self.update(x)\r\n\r\n    @property\r\n    def shape(self):\r\n        return self[\"crossattn\"].shape\r\n\r\n\r\ndef reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):\r\n    param = c[0][0].cond\r\n    is_dict = isinstance(param, dict)\r\n\r\n    if is_dict:\r\n        dict_cond = param\r\n        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}\r\n        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)\r\n    else:\r\n        res = torch.zeros((len(c),) + param.shape, 
device=param.device, dtype=param.dtype)\r\n\r\n    for i, cond_schedule in enumerate(c):\r\n        target_index = 0\r\n        for current, entry in enumerate(cond_schedule):\r\n            if current_step <= entry.end_at_step:\r\n                target_index = current\r\n                break\r\n\r\n        if is_dict:\r\n            for k, param in cond_schedule[target_index].cond.items():\r\n                res[k][i] = param\r\n        else:\r\n            res[i] = cond_schedule[target_index].cond\r\n\r\n    return res\r\n\r\n\r\ndef stack_conds(tensors):\r\n    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes\r\n    # and won't be able to torch.stack them. So this fixes that.\r\n    token_count = max([x.shape[0] for x in tensors])\r\n    for i in range(len(tensors)):\r\n        if tensors[i].shape[0] != token_count:\r\n            last_vector = tensors[i][-1:]\r\n            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])\r\n            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])\r\n\r\n    return torch.stack(tensors)\r\n\r\n\r\n\r\ndef reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):\r\n    param = c.batch[0][0].schedules[0].cond\r\n\r\n    tensors = []\r\n    conds_list = []\r\n\r\n    for composable_prompts in c.batch:\r\n        conds_for_batch = []\r\n\r\n        for composable_prompt in composable_prompts:\r\n            target_index = 0\r\n            for current, entry in enumerate(composable_prompt.schedules):\r\n                if current_step <= entry.end_at_step:\r\n                    target_index = current\r\n                    break\r\n\r\n            conds_for_batch.append((len(tensors), composable_prompt.weight))\r\n            tensors.append(composable_prompt.schedules[target_index].cond)\r\n\r\n        conds_list.append(conds_for_batch)\r\n\r\n    if isinstance(tensors[0], dict):\r\n        keys = 
list(tensors[0].keys())\r\n        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}\r\n        stacked = DictWithShape(stacked, stacked['crossattn'].shape)\r\n    else:\r\n        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)\r\n\r\n    return conds_list, stacked\r\n\r\n\r\nre_attention = re.compile(r\"\"\"\r\n\\\\\\(|\r\n\\\\\\)|\r\n\\\\\\[|\r\n\\\\]|\r\n\\\\\\\\|\r\n\\\\|\r\n\\(|\r\n\\[|\r\n:\\s*([+-]?[.\\d]+)\\s*\\)|\r\n\\)|\r\n]|\r\n[^\\\\()\\[\\]:]+|\r\n:\r\n\"\"\", re.X)\r\n\r\nre_break = re.compile(r\"\\s*\\bBREAK\\b\\s*\", re.S)\r\n\r\ndef parse_prompt_attention(text):\r\n    \"\"\"\r\n    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.\r\n    Accepted tokens are:\r\n      (abc) - increases attention to abc by a multiplier of 1.1\r\n      (abc:3.12) - increases attention to abc by a multiplier of 3.12\r\n      [abc] - decreases attention to abc by a multiplier of 1.1\r\n      \\( - literal character '('\r\n      \\[ - literal character '['\r\n      \\) - literal character ')'\r\n      \\] - literal character ']'\r\n      \\\\ - literal character '\\'\r\n      anything else - just text\r\n\r\n    >>> parse_prompt_attention('normal text')\r\n    [['normal text', 1.0]]\r\n    >>> parse_prompt_attention('an (important) word')\r\n    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]\r\n    >>> parse_prompt_attention('(unbalanced')\r\n    [['unbalanced', 1.1]]\r\n    >>> parse_prompt_attention('\\(literal\\]')\r\n    [['(literal]', 1.0]]\r\n    >>> parse_prompt_attention('(unnecessary)(parens)')\r\n    [['unnecessaryparens', 1.1]]\r\n    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')\r\n    [['a ', 1.0],\r\n     ['house', 1.5730000000000004],\r\n     [' ', 1.1],\r\n     ['on', 1.0],\r\n     [' a ', 1.1],\r\n     ['hill', 0.55],\r\n     [', sun, ', 1.1],\r\n     ['sky', 1.4641000000000006],\r\n     ['.', 1.1]]\r\n    \"\"\"\r\n\r\n    
res = []\r\n    round_brackets = []\r\n    square_brackets = []\r\n\r\n    round_bracket_multiplier = 1.1\r\n    square_bracket_multiplier = 1 / 1.1\r\n\r\n    def multiply_range(start_position, multiplier):\r\n        for p in range(start_position, len(res)):\r\n            res[p][1] *= multiplier\r\n\r\n    for m in re_attention.finditer(text):\r\n        text = m.group(0)\r\n        weight = m.group(1)\r\n\r\n        if text.startswith('\\\\'):\r\n            res.append([text[1:], 1.0])\r\n        elif text == '(':\r\n            round_brackets.append(len(res))\r\n        elif text == '[':\r\n            square_brackets.append(len(res))\r\n        elif weight is not None and round_brackets:\r\n            multiply_range(round_brackets.pop(), float(weight))\r\n        elif text == ')' and round_brackets:\r\n            multiply_range(round_brackets.pop(), round_bracket_multiplier)\r\n        elif text == ']' and square_brackets:\r\n            multiply_range(square_brackets.pop(), square_bracket_multiplier)\r\n        else:\r\n            parts = re.split(re_break, text)\r\n            for i, part in enumerate(parts):\r\n                if i > 0:\r\n                    res.append([\"BREAK\", -1])\r\n                res.append([part, 1.0])\r\n\r\n    for pos in round_brackets:\r\n        multiply_range(pos, round_bracket_multiplier)\r\n\r\n    for pos in square_brackets:\r\n        multiply_range(pos, square_bracket_multiplier)\r\n\r\n    if len(res) == 0:\r\n        res = [[\"\", 1.0]]\r\n\r\n    # merge runs of identical weights\r\n    i = 0\r\n    while i + 1 < len(res):\r\n        if res[i][1] == res[i + 1][1]:\r\n            res[i][0] += res[i + 1][0]\r\n            res.pop(i + 1)\r\n        else:\r\n            i += 1\r\n\r\n    return res\r\n\r\nif __name__ == \"__main__\":\r\n    import doctest\r\n    doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)\r\nelse:\r\n    import torch  # doctest faster\r\n"
  },
  {
    "path": "modules/realesrgan_model.py",
    "content": "import os\r\n\r\nfrom modules import modelloader, errors\r\nfrom modules.shared import cmd_opts, opts\r\nfrom modules.upscaler import Upscaler, UpscalerData\r\nfrom modules.upscaler_utils import upscale_with_model\r\n\r\n\r\nclass UpscalerRealESRGAN(Upscaler):\r\n    def __init__(self, path):\r\n        self.name = \"RealESRGAN\"\r\n        self.user_path = path\r\n        super().__init__()\r\n        self.enable = True\r\n        self.scalers = []\r\n        scalers = get_realesrgan_models(self)\r\n\r\n        local_model_paths = self.find_models(ext_filter=[\".pth\"])\r\n        for scaler in scalers:\r\n            if scaler.local_data_path.startswith(\"http\"):\r\n                filename = modelloader.friendly_name(scaler.local_data_path)\r\n                local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f\"{filename}.pth\")]\r\n                if local_model_candidates:\r\n                    scaler.local_data_path = local_model_candidates[0]\r\n\r\n            if scaler.name in opts.realesrgan_enabled_models:\r\n                self.scalers.append(scaler)\r\n\r\n    def do_upscale(self, img, path):\r\n        if not self.enable:\r\n            return img\r\n\r\n        try:\r\n            info = self.load_model(path)\r\n        except Exception:\r\n            errors.report(f\"Unable to load RealESRGAN model {path}\", exc_info=True)\r\n            return img\r\n\r\n        model_descriptor = modelloader.load_spandrel_model(\r\n            info.local_data_path,\r\n            device=self.device,\r\n            prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),\r\n            expected_architecture=\"ESRGAN\",  # \"RealESRGAN\" isn't a specific thing for Spandrel\r\n        )\r\n        return upscale_with_model(\r\n            model_descriptor,\r\n            img,\r\n            tile_size=opts.ESRGAN_tile,\r\n            tile_overlap=opts.ESRGAN_tile_overlap,\r\n            # 
TODO: `outscale`?\r\n        )\r\n\r\n    def load_model(self, path):\r\n        for scaler in self.scalers:\r\n            if scaler.data_path == path:\r\n                if scaler.local_data_path.startswith(\"http\"):\r\n                    scaler.local_data_path = modelloader.load_file_from_url(\r\n                        scaler.data_path,\r\n                        model_dir=self.model_download_path,\r\n                    )\r\n                if not os.path.exists(scaler.local_data_path):\r\n                    raise FileNotFoundError(f\"RealESRGAN data missing: {scaler.local_data_path}\")\r\n                return scaler\r\n        raise ValueError(f\"Unable to find model info: {path}\")\r\n\r\n\r\ndef get_realesrgan_models(scaler: UpscalerRealESRGAN):\r\n    return [\r\n        UpscalerData(\r\n            name=\"R-ESRGAN General 4xV3\",\r\n            path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth\",\r\n            scale=4,\r\n            upscaler=scaler,\r\n        ),\r\n        UpscalerData(\r\n            name=\"R-ESRGAN General WDN 4xV3\",\r\n            path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth\",\r\n            scale=4,\r\n            upscaler=scaler,\r\n        ),\r\n        UpscalerData(\r\n            name=\"R-ESRGAN AnimeVideo\",\r\n            path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth\",\r\n            scale=4,\r\n            upscaler=scaler,\r\n        ),\r\n        UpscalerData(\r\n            name=\"R-ESRGAN 4x+\",\r\n            path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth\",\r\n            scale=4,\r\n            upscaler=scaler,\r\n        ),\r\n        UpscalerData(\r\n            name=\"R-ESRGAN 4x+ Anime6B\",\r\n            
path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth\",\r\n            scale=4,\r\n            upscaler=scaler,\r\n        ),\r\n        UpscalerData(\r\n            name=\"R-ESRGAN 2x+\",\r\n            path=\"https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth\",\r\n            scale=2,\r\n            upscaler=scaler,\r\n        ),\r\n    ]\r\n"
  },
  {
    "path": "modules/restart.py",
    "content": "import os\nfrom pathlib import Path\n\nfrom modules.paths_internal import script_path\n\n\ndef is_restartable() -> bool:\n    \"\"\"\n    Return True if the webui is restartable (i.e. there is something watching to restart it with)\n    \"\"\"\n    return bool(os.environ.get('SD_WEBUI_RESTART'))\n\n\ndef restart_program() -> None:\n    \"\"\"creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again\"\"\"\n\n    tmpdir = Path(script_path) / \"tmp\"\n    tmpdir.mkdir(parents=True, exist_ok=True)\n    (tmpdir / \"restart\").touch()\n\n    stop_program()\n\n\ndef stop_program() -> None:\n    os._exit(0)\n"
  },
  {
    "path": "modules/rng.py",
    "content": "import torch\r\n\r\nfrom modules import devices, rng_philox, shared\r\n\r\n\r\ndef randn(seed, shape, generator=None):\r\n    \"\"\"Generate a tensor with random numbers from a normal distribution using seed.\r\n\r\n    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.\"\"\"\r\n\r\n    manual_seed(seed)\r\n\r\n    if shared.opts.randn_source == \"NV\":\r\n        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)\r\n\r\n    if shared.opts.randn_source == \"CPU\" or devices.device.type == 'mps':\r\n        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)\r\n\r\n    return torch.randn(shape, device=devices.device, generator=generator)\r\n\r\n\r\ndef randn_local(seed, shape):\r\n    \"\"\"Generate a tensor with random numbers from a normal distribution using seed.\r\n\r\n    Does not change the global random number generator. You can only generate the seed's first tensor using this function.\"\"\"\r\n\r\n    if shared.opts.randn_source == \"NV\":\r\n        rng = rng_philox.Generator(seed)\r\n        return torch.asarray(rng.randn(shape), device=devices.device)\r\n\r\n    local_device = devices.cpu if shared.opts.randn_source == \"CPU\" or devices.device.type == 'mps' else devices.device\r\n    local_generator = torch.Generator(local_device).manual_seed(int(seed))\r\n    return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)\r\n\r\n\r\ndef randn_like(x):\r\n    \"\"\"Generate a tensor with random numbers from a normal distribution using the previously initialized generator.\r\n\r\n    Use either randn() or manual_seed() to initialize the generator.\"\"\"\r\n\r\n    if shared.opts.randn_source == \"NV\":\r\n        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)\r\n\r\n    if shared.opts.randn_source == \"CPU\" or x.device.type == 'mps':\r\n        
return torch.randn_like(x, device=devices.cpu).to(x.device)\r\n\r\n    return torch.randn_like(x)\r\n\r\n\r\ndef randn_without_seed(shape, generator=None):\r\n    \"\"\"Generate a tensor with random numbers from a normal distribution using the previously initialized generator.\r\n\r\n    Use either randn() or manual_seed() to initialize the generator.\"\"\"\r\n\r\n    if shared.opts.randn_source == \"NV\":\r\n        return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)\r\n\r\n    if shared.opts.randn_source == \"CPU\" or devices.device.type == 'mps':\r\n        return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)\r\n\r\n    return torch.randn(shape, device=devices.device, generator=generator)\r\n\r\n\r\ndef manual_seed(seed):\r\n    \"\"\"Set up a global random number generator using the specified seed.\"\"\"\r\n\r\n    if shared.opts.randn_source == \"NV\":\r\n        global nv_rng\r\n        nv_rng = rng_philox.Generator(seed)\r\n        return\r\n\r\n    torch.manual_seed(seed)\r\n\r\n\r\ndef create_generator(seed):\r\n    if shared.opts.randn_source == \"NV\":\r\n        return rng_philox.Generator(seed)\r\n\r\n    device = devices.cpu if shared.opts.randn_source == \"CPU\" or devices.device.type == 'mps' else devices.device\r\n    generator = torch.Generator(device).manual_seed(int(seed))\r\n    return generator\r\n\r\n\r\n# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3\r\ndef slerp(val, low, high):\r\n    low_norm = low/torch.norm(low, dim=1, keepdim=True)\r\n    high_norm = high/torch.norm(high, dim=1, keepdim=True)\r\n    dot = (low_norm*high_norm).sum(1)\r\n\r\n    if dot.mean() > 0.9995:\r\n        return low * val + high * (1 - val)\r\n\r\n    omega = torch.acos(dot)\r\n    so = torch.sin(omega)\r\n    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high\r\n    return res\r\n\r\n\r\nclass 
ImageRNG:\r\n    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):\r\n        self.shape = tuple(map(int, shape))\r\n        self.seeds = seeds\r\n        self.subseeds = subseeds\r\n        self.subseed_strength = subseed_strength\r\n        self.seed_resize_from_h = seed_resize_from_h\r\n        self.seed_resize_from_w = seed_resize_from_w\r\n\r\n        self.generators = [create_generator(seed) for seed in seeds]\r\n\r\n        self.is_first = True\r\n\r\n    def first(self):\r\n        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))\r\n\r\n        xs = []\r\n\r\n        for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):\r\n            subnoise = None\r\n            if self.subseeds is not None and self.subseed_strength != 0:\r\n                subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]\r\n                subnoise = randn(subseed, noise_shape)\r\n\r\n            if noise_shape != self.shape:\r\n                noise = randn(seed, noise_shape)\r\n            else:\r\n                noise = randn(seed, self.shape, generator=generator)\r\n\r\n            if subnoise is not None:\r\n                noise = slerp(self.subseed_strength, noise, subnoise)\r\n\r\n            if noise_shape != self.shape:\r\n                x = randn(seed, self.shape, generator=generator)\r\n                dx = (self.shape[2] - noise_shape[2]) // 2\r\n                dy = (self.shape[1] - noise_shape[1]) // 2\r\n                w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx\r\n                h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy\r\n                tx = 0 if dx < 0 else dx\r\n                ty = 0 if dy < 0 else dy\r\n                dx = max(-dx, 0)\r\n                dy = max(-dy, 0)\r\n\r\n                x[:, ty:ty + 
h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]\r\n                noise = x\r\n\r\n            xs.append(noise)\r\n\r\n        eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0\r\n        if eta_noise_seed_delta:\r\n            self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]\r\n\r\n        return torch.stack(xs).to(shared.device)\r\n\r\n    def next(self):\r\n        if self.is_first:\r\n            self.is_first = False\r\n            return self.first()\r\n\r\n        xs = []\r\n        for generator in self.generators:\r\n            x = randn_without_seed(self.shape, generator=generator)\r\n            xs.append(x)\r\n\r\n        return torch.stack(xs).to(shared.device)\r\n\r\n\r\ndevices.randn = randn\r\ndevices.randn_local = randn_local\r\ndevices.randn_like = randn_like\r\ndevices.randn_without_seed = randn_without_seed\r\ndevices.manual_seed = manual_seed\r\n"
  },
  {
    "path": "modules/rng_philox.py",
    "content": "\"\"\"RNG imitiating torch cuda randn on CPU. You are welcome.\r\n\r\nUsage:\r\n\r\n```\r\ng = Generator(seed=0)\r\nprint(g.randn(shape=(3, 4)))\r\n```\r\n\r\nExpected output:\r\n```\r\n[[-0.92466259 -0.42534415 -2.6438457   0.14518388]\r\n [-0.12086647 -0.57972564 -0.62285122 -0.32838709]\r\n [-1.07454231 -0.36314407 -1.67105067  2.26550497]]\r\n```\r\n\"\"\"\r\n\r\nimport numpy as np\r\n\r\nphilox_m = [0xD2511F53, 0xCD9E8D57]\r\nphilox_w = [0x9E3779B9, 0xBB67AE85]\r\n\r\ntwo_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)\r\ntwo_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)\r\n\r\n\r\ndef uint32(x):\r\n    \"\"\"Converts (N,) np.uint64 array into (2, N) np.unit32 array.\"\"\"\r\n    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)\r\n\r\n\r\ndef philox4_round(counter, key):\r\n    \"\"\"A single round of the Philox 4x32 random number generator.\"\"\"\r\n\r\n    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])\r\n    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])\r\n\r\n    counter[0] = v2[1] ^ counter[1] ^ key[0]\r\n    counter[1] = v2[0]\r\n    counter[2] = v1[1] ^ counter[3] ^ key[1]\r\n    counter[3] = v1[0]\r\n\r\n\r\ndef philox4_32(counter, key, rounds=10):\r\n    \"\"\"Generates 32-bit random numbers using the Philox 4x32 random number generator.\r\n\r\n    Parameters:\r\n        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).\r\n        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).\r\n        rounds (int): The number of rounds to perform.\r\n\r\n    Returns:\r\n        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.\r\n    \"\"\"\r\n\r\n    for _ in range(rounds - 1):\r\n        philox4_round(counter, key)\r\n\r\n        key[0] = key[0] + philox_w[0]\r\n        key[1] = key[1] + philox_w[1]\r\n\r\n    philox4_round(counter, key)\r\n    
return counter\r\n\r\n\r\ndef box_muller(x, y):\r\n    \"\"\"Returns just the first out of two numbers generated by Box–Muller transform algorithm.\"\"\"\r\n    u = x * two_pow32_inv + two_pow32_inv / 2\r\n    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2\r\n\r\n    s = np.sqrt(-2.0 * np.log(u))\r\n\r\n    r1 = s * np.sin(v)\r\n    return r1.astype(np.float32)\r\n\r\n\r\nclass Generator:\r\n    \"\"\"RNG that produces same outputs as torch.randn(..., device='cuda') on CPU\"\"\"\r\n\r\n    def __init__(self, seed):\r\n        self.seed = seed\r\n        self.offset = 0\r\n\r\n    def randn(self, shape):\r\n        \"\"\"Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform.\"\"\"\r\n\r\n        n = 1\r\n        for x in shape:\r\n            n *= x\r\n\r\n        counter = np.zeros((4, n), dtype=np.uint32)\r\n        counter[0] = self.offset\r\n        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]\r\n        self.offset += 1\r\n\r\n        key = np.empty(n, dtype=np.uint64)\r\n        key.fill(self.seed)\r\n        key = uint32(key)\r\n\r\n        g = philox4_32(counter, key)\r\n\r\n        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]\r\n"
  },
  {
    "path": "modules/safe.py",
    "content": "# this code is adapted from the script contributed by anon from /h/\r\n\r\nimport pickle\r\nimport collections\r\n\r\nimport torch\r\nimport numpy\r\nimport _codecs\r\nimport zipfile\r\nimport re\r\n\r\n\r\n# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage\r\nfrom modules import errors\r\n\r\nTypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage\r\n\r\ndef encode(*args):\r\n    out = _codecs.encode(*args)\r\n    return out\r\n\r\n\r\nclass RestrictedUnpickler(pickle.Unpickler):\r\n    extra_handler = None\r\n\r\n    def persistent_load(self, saved_id):\r\n        assert saved_id[0] == 'storage'\r\n\r\n        try:\r\n            return TypedStorage(_internal=True)\r\n        except TypeError:\r\n            return TypedStorage()  # PyTorch before 2.0 does not have the _internal argument\r\n\r\n    def find_class(self, module, name):\r\n        if self.extra_handler is not None:\r\n            res = self.extra_handler(module, name)\r\n            if res is not None:\r\n                return res\r\n\r\n        if module == 'collections' and name == 'OrderedDict':\r\n            return getattr(collections, name)\r\n        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:\r\n            return getattr(torch._utils, name)\r\n        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:\r\n            return getattr(torch, name)\r\n        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:\r\n            return getattr(torch.nn.modules.container, name)\r\n        if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:\r\n            return getattr(numpy.core.multiarray, name)\r\n        if module == 'numpy' and name in ['dtype', 'ndarray']:\r\n            
return getattr(numpy, name)\r\n        if module == '_codecs' and name == 'encode':\r\n            return encode\r\n        if module == \"pytorch_lightning.callbacks\" and name == 'model_checkpoint':\r\n            import pytorch_lightning.callbacks\r\n            return pytorch_lightning.callbacks.model_checkpoint\r\n        if module == \"pytorch_lightning.callbacks.model_checkpoint\" and name == 'ModelCheckpoint':\r\n            import pytorch_lightning.callbacks.model_checkpoint\r\n            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint\r\n        if module == \"__builtin__\" and name == 'set':\r\n            return set\r\n\r\n        # Forbid everything else.\r\n        raise Exception(f\"global '{module}/{name}' is forbidden\")\r\n\r\n\r\n# Regular expression that accepts 'dirname/version', 'dirname/byteorder', 'dirname/data.pkl', '.data/serialization_id', and 'dirname/data/<number>'\r\nallowed_zip_names_re = re.compile(r\"^([^/]+)/((data/\\d+)|version|byteorder|.data/serialization_id|(data\\.pkl))$\")\r\ndata_pkl_re = re.compile(r\"^([^/]+)/data\\.pkl$\")\r\n\r\ndef check_zip_filenames(filename, names):\r\n    for name in names:\r\n        if allowed_zip_names_re.match(name):\r\n            continue\r\n\r\n        raise Exception(f\"bad file inside {filename}: {name}\")\r\n\r\n\r\ndef check_pt(filename, extra_handler):\r\n    try:\r\n\r\n        # new pytorch format is a zip file\r\n        with zipfile.ZipFile(filename) as z:\r\n            check_zip_filenames(filename, z.namelist())\r\n\r\n            # find filename of data.pkl in zip file: '<directory name>/data.pkl'\r\n            data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)]\r\n            if len(data_pkl_filenames) == 0:\r\n                raise Exception(f\"data.pkl not found in {filename}\")\r\n            if len(data_pkl_filenames) > 1:\r\n                raise Exception(f\"Multiple data.pkl found in {filename}\")\r\n            with 
z.open(data_pkl_filenames[0]) as file:\r\n                unpickler = RestrictedUnpickler(file)\r\n                unpickler.extra_handler = extra_handler\r\n                unpickler.load()\r\n\r\n    except zipfile.BadZipfile:\r\n\r\n        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle\r\n        with open(filename, \"rb\") as file:\r\n            unpickler = RestrictedUnpickler(file)\r\n            unpickler.extra_handler = extra_handler\r\n            for _ in range(5):\r\n                unpickler.load()\r\n\r\n\r\ndef load(filename, *args, **kwargs):\r\n    return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)\r\n\r\n\r\ndef load_with_extra(filename, extra_handler=None, *args, **kwargs):\r\n    \"\"\"\r\n    this function is intended to be used by extensions that want to load models with\r\n    some extra classes in them that the usual unpickler would find suspicious.\r\n\r\n    Use the extra_handler argument to specify a function that takes module and field name as text,\r\n    and returns that field's value:\r\n\r\n    ```python\r\n    def extra(module, name):\r\n        if module == 'collections' and name == 'OrderedDict':\r\n            return collections.OrderedDict\r\n\r\n        return None\r\n\r\n    safe.load_with_extra('model.pt', extra_handler=extra)\r\n    ```\r\n\r\n    The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is\r\n    definitely unsafe.\r\n    \"\"\"\r\n\r\n    from modules import shared\r\n\r\n    try:\r\n        if not shared.cmd_opts.disable_safe_unpickle:\r\n            check_pt(filename, extra_handler)\r\n\r\n    except pickle.UnpicklingError:\r\n        errors.report(\r\n            f\"Error verifying pickled file from {filename}\\n\"\r\n            \"-----> !!!! The file is most likely corrupted !!!! 
<-----\\n\"\r\n            \"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\\n\\n\",\r\n            exc_info=True,\r\n        )\r\n        return None\r\n    except Exception:\r\n        errors.report(\r\n            f\"Error verifying pickled file from {filename}\\n\"\r\n            f\"The file may be malicious, so the program is not going to read it.\\n\"\r\n            f\"You can skip this check with --disable-safe-unpickle commandline argument.\\n\\n\",\r\n            exc_info=True,\r\n        )\r\n        return None\r\n\r\n    return unsafe_torch_load(filename, *args, **kwargs)\r\n\r\n\r\nclass Extra:\r\n    \"\"\"\r\n    A class for temporarily setting the global handler for when you can't explicitly call load_with_extra\r\n    (because it's not your code making the torch.load call). The intended use is like this:\r\n\r\n```\r\nimport torch\r\nfrom modules import safe\r\n\r\ndef handler(module, name):\r\n    if module == 'torch' and name in ['float64', 'float16']:\r\n        return getattr(torch, name)\r\n\r\n    return None\r\n\r\nwith safe.Extra(handler):\r\n    x = torch.load('model.pt')\r\n```\r\n    \"\"\"\r\n\r\n    def __init__(self, handler):\r\n        self.handler = handler\r\n\r\n    def __enter__(self):\r\n        global global_extra_handler\r\n\r\n        assert global_extra_handler is None, 'already inside an Extra() block'\r\n        global_extra_handler = self.handler\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        global global_extra_handler\r\n\r\n        global_extra_handler = None\r\n\r\n\r\nunsafe_torch_load = torch.load\r\ntorch.load = load\r\nglobal_extra_handler = None\r\n"
  },
  {
    "path": "modules/script_callbacks.py",
    "content": "from __future__ import annotations\r\n\r\nimport dataclasses\r\nimport inspect\r\nimport os\r\nfrom typing import Optional, Any\r\n\r\nfrom fastapi import FastAPI\r\nfrom gradio import Blocks\r\n\r\nfrom modules import errors, timer, extensions, shared, util\r\n\r\n\r\ndef report_exception(c, job):\r\n    errors.report(f\"Error executing callback {job} for {c.script}\", exc_info=True)\r\n\r\n\r\nclass ImageSaveParams:\r\n    def __init__(self, image, p, filename, pnginfo):\r\n        self.image = image\r\n        \"\"\"the PIL image itself\"\"\"\r\n\r\n        self.p = p\r\n        \"\"\"p object with processing parameters; either StableDiffusionProcessing or an object with same fields\"\"\"\r\n\r\n        self.filename = filename\r\n        \"\"\"name of file that the image would be saved to\"\"\"\r\n\r\n        self.pnginfo = pnginfo\r\n        \"\"\"dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'\"\"\"\r\n\r\n\r\nclass ExtraNoiseParams:\r\n    def __init__(self, noise, x, xi):\r\n        self.noise = noise\r\n        \"\"\"Random noise generated by the seed\"\"\"\r\n\r\n        self.x = x\r\n        \"\"\"Latent representation of the image\"\"\"\r\n\r\n        self.xi = xi\r\n        \"\"\"Noisy latent representation of the image\"\"\"\r\n\r\n\r\nclass CFGDenoiserParams:\r\n    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond, denoiser=None):\r\n        self.x = x\r\n        \"\"\"Latent image representation in the process of being denoised\"\"\"\r\n\r\n        self.image_cond = image_cond\r\n        \"\"\"Conditioning image\"\"\"\r\n\r\n        self.sigma = sigma\r\n        \"\"\"Current sigma noise step value\"\"\"\r\n\r\n        self.sampling_step = sampling_step\r\n        \"\"\"Current Sampling step number\"\"\"\r\n\r\n        self.total_sampling_steps = total_sampling_steps\r\n        \"\"\"Total number of sampling steps planned\"\"\"\r\n\r\n   
     self.text_cond = text_cond\r\n        \"\"\" Encoder hidden states of text conditioning from prompt\"\"\"\r\n\r\n        self.text_uncond = text_uncond\r\n        \"\"\" Encoder hidden states of text conditioning from negative prompt\"\"\"\r\n\r\n        self.denoiser = denoiser\r\n        \"\"\"Current CFGDenoiser object with processing parameters\"\"\"\r\n\r\n\r\nclass CFGDenoisedParams:\r\n    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):\r\n        self.x = x\r\n        \"\"\"Latent image representation in the process of being denoised\"\"\"\r\n\r\n        self.sampling_step = sampling_step\r\n        \"\"\"Current Sampling step number\"\"\"\r\n\r\n        self.total_sampling_steps = total_sampling_steps\r\n        \"\"\"Total number of sampling steps planned\"\"\"\r\n\r\n        self.inner_model = inner_model\r\n        \"\"\"Inner model reference used for denoising\"\"\"\r\n\r\n\r\nclass AfterCFGCallbackParams:\r\n    def __init__(self, x, sampling_step, total_sampling_steps):\r\n        self.x = x\r\n        \"\"\"Latent image representation in the process of being denoised\"\"\"\r\n\r\n        self.sampling_step = sampling_step\r\n        \"\"\"Current Sampling step number\"\"\"\r\n\r\n        self.total_sampling_steps = total_sampling_steps\r\n        \"\"\"Total number of sampling steps planned\"\"\"\r\n\r\n\r\nclass UiTrainTabParams:\r\n    def __init__(self, txt2img_preview_params):\r\n        self.txt2img_preview_params = txt2img_preview_params\r\n\r\n\r\nclass ImageGridLoopParams:\r\n    def __init__(self, imgs, cols, rows):\r\n        self.imgs = imgs\r\n        self.cols = cols\r\n        self.rows = rows\r\n\r\n\r\n@dataclasses.dataclass\r\nclass BeforeTokenCounterParams:\r\n    prompt: str\r\n    steps: int\r\n    styles: list\r\n\r\n    is_positive: bool = True\r\n\r\n\r\n@dataclasses.dataclass\r\nclass ScriptCallback:\r\n    script: str\r\n    callback: any\r\n    name: str = \"unnamed\"\r\n\r\n\r\ndef 
add_callback(callbacks, fun, *, name=None, category='unknown', filename=None):\r\n    if filename is None:\r\n        stack = [x for x in inspect.stack() if x.filename != __file__]\r\n        filename = stack[0].filename if stack else 'unknown file'\r\n\r\n    extension = extensions.find_extension(filename)\r\n    extension_name = extension.canonical_name if extension else 'base'\r\n\r\n    callback_name = f\"{extension_name}/{os.path.basename(filename)}/{category}\"\r\n    if name is not None:\r\n        callback_name += f'/{name}'\r\n\r\n    unique_callback_name = callback_name\r\n    for index in range(1000):\r\n        existing = any(x.name == unique_callback_name for x in callbacks)\r\n        if not existing:\r\n            break\r\n\r\n        unique_callback_name = f'{callback_name}-{index+1}'\r\n\r\n    callbacks.append(ScriptCallback(filename, fun, unique_callback_name))\r\n\r\n\r\ndef sort_callbacks(category, unordered_callbacks, *, enable_user_sort=True):\r\n    callbacks = unordered_callbacks.copy()\r\n    callback_lookup = {x.name: x for x in callbacks}\r\n    dependencies = {}\r\n\r\n    order_instructions = {}\r\n    for extension in extensions.extensions:\r\n        for order_instruction in extension.metadata.list_callback_order_instructions():\r\n            if order_instruction.name in callback_lookup:\r\n                if order_instruction.name not in order_instructions:\r\n                    order_instructions[order_instruction.name] = []\r\n\r\n                order_instructions[order_instruction.name].append(order_instruction)\r\n\r\n    if order_instructions:\r\n        for callback in callbacks:\r\n            dependencies[callback.name] = []\r\n\r\n        for callback in callbacks:\r\n            for order_instruction in order_instructions.get(callback.name, []):\r\n                for after in order_instruction.after:\r\n                    if after not in callback_lookup:\r\n                        continue\r\n\r\n                    
dependencies[callback.name].append(after)\r\n\r\n                for before in order_instruction.before:\r\n                    if before not in callback_lookup:\r\n                        continue\r\n\r\n                    dependencies[before].append(callback.name)\r\n\r\n        sorted_names = util.topological_sort(dependencies)\r\n        callbacks = [callback_lookup[x] for x in sorted_names]\r\n\r\n    if enable_user_sort:\r\n        for name in reversed(getattr(shared.opts, 'prioritized_callbacks_' + category, [])):\r\n            index = next((i for i, callback in enumerate(callbacks) if callback.name == name), None)\r\n            if index is not None:\r\n                callbacks.insert(0, callbacks.pop(index))\r\n\r\n    return callbacks\r\n\r\n\r\ndef ordered_callbacks(category, unordered_callbacks=None, *, enable_user_sort=True):\r\n    if unordered_callbacks is None:\r\n        unordered_callbacks = callback_map.get('callbacks_' + category, [])\r\n\r\n    if not enable_user_sort:\r\n        return sort_callbacks(category, unordered_callbacks, enable_user_sort=False)\r\n\r\n    callbacks = ordered_callbacks_map.get(category)\r\n    if callbacks is not None and len(callbacks) == len(unordered_callbacks):\r\n        return callbacks\r\n\r\n    callbacks = sort_callbacks(category, unordered_callbacks)\r\n\r\n    ordered_callbacks_map[category] = callbacks\r\n    return callbacks\r\n\r\n\r\ndef enumerate_callbacks():\r\n    for category, callbacks in callback_map.items():\r\n        if category.startswith('callbacks_'):\r\n            category = category[10:]\r\n\r\n        yield category, callbacks\r\n\r\n\r\ncallback_map = dict(\r\n    callbacks_app_started=[],\r\n    callbacks_model_loaded=[],\r\n    callbacks_ui_tabs=[],\r\n    callbacks_ui_train_tabs=[],\r\n    callbacks_ui_settings=[],\r\n    callbacks_before_image_saved=[],\r\n    callbacks_image_saved=[],\r\n    callbacks_extra_noise=[],\r\n    callbacks_cfg_denoiser=[],\r\n    
callbacks_cfg_denoised=[],\r\n    callbacks_cfg_after_cfg=[],\r\n    callbacks_before_component=[],\r\n    callbacks_after_component=[],\r\n    callbacks_image_grid=[],\r\n    callbacks_infotext_pasted=[],\r\n    callbacks_script_unloaded=[],\r\n    callbacks_before_ui=[],\r\n    callbacks_on_reload=[],\r\n    callbacks_list_optimizers=[],\r\n    callbacks_list_unets=[],\r\n    callbacks_before_token_counter=[],\r\n)\r\n\r\nordered_callbacks_map = {}\r\n\r\n\r\ndef clear_callbacks():\r\n    for callback_list in callback_map.values():\r\n        callback_list.clear()\r\n\r\n    ordered_callbacks_map.clear()\r\n\r\n\r\ndef app_started_callback(demo: Optional[Blocks], app: FastAPI):\r\n    for c in ordered_callbacks('app_started'):\r\n        try:\r\n            c.callback(demo, app)\r\n            timer.startup_timer.record(os.path.basename(c.script))\r\n        except Exception:\r\n            report_exception(c, 'app_started_callback')\r\n\r\n\r\ndef app_reload_callback():\r\n    for c in ordered_callbacks('on_reload'):\r\n        try:\r\n            c.callback()\r\n        except Exception:\r\n            report_exception(c, 'callbacks_on_reload')\r\n\r\n\r\ndef model_loaded_callback(sd_model):\r\n    for c in ordered_callbacks('model_loaded'):\r\n        try:\r\n            c.callback(sd_model)\r\n        except Exception:\r\n            report_exception(c, 'model_loaded_callback')\r\n\r\n\r\ndef ui_tabs_callback():\r\n    res = []\r\n\r\n    for c in ordered_callbacks('ui_tabs'):\r\n        try:\r\n            res += c.callback() or []\r\n        except Exception:\r\n            report_exception(c, 'ui_tabs_callback')\r\n\r\n    return res\r\n\r\n\r\ndef ui_train_tabs_callback(params: UiTrainTabParams):\r\n    for c in ordered_callbacks('ui_train_tabs'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'callbacks_ui_train_tabs')\r\n\r\n\r\ndef ui_settings_callback():\r\n    for c in 
ordered_callbacks('ui_settings'):\r\n        try:\r\n            c.callback()\r\n        except Exception:\r\n            report_exception(c, 'ui_settings_callback')\r\n\r\n\r\ndef before_image_saved_callback(params: ImageSaveParams):\r\n    for c in ordered_callbacks('before_image_saved'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'before_image_saved_callback')\r\n\r\n\r\ndef image_saved_callback(params: ImageSaveParams):\r\n    for c in ordered_callbacks('image_saved'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'image_saved_callback')\r\n\r\n\r\ndef extra_noise_callback(params: ExtraNoiseParams):\r\n    for c in ordered_callbacks('extra_noise'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'callbacks_extra_noise')\r\n\r\n\r\ndef cfg_denoiser_callback(params: CFGDenoiserParams):\r\n    for c in ordered_callbacks('cfg_denoiser'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'cfg_denoiser_callback')\r\n\r\n\r\ndef cfg_denoised_callback(params: CFGDenoisedParams):\r\n    for c in ordered_callbacks('cfg_denoised'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'cfg_denoised_callback')\r\n\r\n\r\ndef cfg_after_cfg_callback(params: AfterCFGCallbackParams):\r\n    for c in ordered_callbacks('cfg_after_cfg'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'cfg_after_cfg_callback')\r\n\r\n\r\ndef before_component_callback(component, **kwargs):\r\n    for c in ordered_callbacks('before_component'):\r\n        try:\r\n            c.callback(component, **kwargs)\r\n        except Exception:\r\n            report_exception(c, 
'before_component_callback')\r\n\r\n\r\ndef after_component_callback(component, **kwargs):\r\n    for c in ordered_callbacks('after_component'):\r\n        try:\r\n            c.callback(component, **kwargs)\r\n        except Exception:\r\n            report_exception(c, 'after_component_callback')\r\n\r\n\r\ndef image_grid_callback(params: ImageGridLoopParams):\r\n    for c in ordered_callbacks('image_grid'):\r\n        try:\r\n            c.callback(params)\r\n        except Exception:\r\n            report_exception(c, 'image_grid')\r\n\r\n\r\ndef infotext_pasted_callback(infotext: str, params: dict[str, Any]):\r\n    for c in ordered_callbacks('infotext_pasted'):\r\n        try:\r\n            c.callback(infotext, params)\r\n        except Exception:\r\n            report_exception(c, 'infotext_pasted')\r\n\r\n\r\ndef script_unloaded_callback():\r\n    for c in reversed(ordered_callbacks('script_unloaded')):\r\n        try:\r\n            c.callback()\r\n        except Exception:\r\n            report_exception(c, 'script_unloaded')\r\n\r\n\r\ndef before_ui_callback():\r\n    for c in reversed(ordered_callbacks('before_ui')):\r\n        try:\r\n            c.callback()\r\n        except Exception:\r\n            report_exception(c, 'before_ui')\r\n\r\n\r\ndef list_optimizers_callback():\r\n    res = []\r\n\r\n    for c in ordered_callbacks('list_optimizers'):\r\n        try:\r\n            c.callback(res)\r\n        except Exception:\r\n            report_exception(c, 'list_optimizers')\r\n\r\n    return res\r\n\r\n\r\ndef list_unets_callback():\r\n    res = []\r\n\r\n    for c in ordered_callbacks('list_unets'):\r\n        try:\r\n            c.callback(res)\r\n        except Exception:\r\n            report_exception(c, 'list_unets')\r\n\r\n    return res\r\n\r\n\r\ndef before_token_counter_callback(params: BeforeTokenCounterParams):\r\n    for c in ordered_callbacks('before_token_counter'):\r\n        try:\r\n            c.callback(params)\r\n        except 
Exception:\r\n            report_exception(c, 'before_token_counter')\r\n\r\n\r\ndef remove_current_script_callbacks():\r\n    stack = [x for x in inspect.stack() if x.filename != __file__]\r\n    filename = stack[0].filename if stack else 'unknown file'\r\n    if filename == 'unknown file':\r\n        return\r\n    for callback_list in callback_map.values():\r\n        for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:\r\n            callback_list.remove(callback_to_remove)\r\n    for ordered_callbacks_list in ordered_callbacks_map.values():\r\n        for callback_to_remove in [cb for cb in ordered_callbacks_list if cb.script == filename]:\r\n            ordered_callbacks_list.remove(callback_to_remove)\r\n\r\n\r\ndef remove_callbacks_for_function(callback_func):\r\n    for callback_list in callback_map.values():\r\n        for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:\r\n            callback_list.remove(callback_to_remove)\r\n    for ordered_callback_list in ordered_callbacks_map.values():\r\n        for callback_to_remove in [cb for cb in ordered_callback_list if cb.callback == callback_func]:\r\n            ordered_callback_list.remove(callback_to_remove)\r\n\r\n\r\ndef on_app_started(callback, *, name=None):\r\n    \"\"\"register a function to be called when the webui started, the gradio `Block` component and\r\n    fastapi `FastAPI` object are passed as the arguments\"\"\"\r\n    add_callback(callback_map['callbacks_app_started'], callback, name=name, category='app_started')\r\n\r\n\r\ndef on_before_reload(callback, *, name=None):\r\n    \"\"\"register a function to be called just before the server reloads.\"\"\"\r\n    add_callback(callback_map['callbacks_on_reload'], callback, name=name, category='on_reload')\r\n\r\n\r\ndef on_model_loaded(callback, *, name=None):\r\n    \"\"\"register a function to be called when the stable diffusion model is created; the model is\r\n    passed as an 
argument; this function is also called when the script is reloaded. \"\"\"\r\n    add_callback(callback_map['callbacks_model_loaded'], callback, name=name, category='model_loaded')\r\n\r\n\r\ndef on_ui_tabs(callback, *, name=None):\r\n    \"\"\"register a function to be called when the UI is creating new tabs.\r\n    The function must either return a None, which means no new tabs to be added, or a list, where\r\n    each element is a tuple:\r\n        (gradio_component, title, elem_id)\r\n\r\n    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)\r\n    title is tab text displayed to user in the UI\r\n    elem_id is HTML id for the tab\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_ui_tabs'], callback, name=name, category='ui_tabs')\r\n\r\n\r\ndef on_ui_train_tabs(callback, *, name=None):\r\n    \"\"\"register a function to be called when the UI is creating new tabs for the train tab.\r\n    Create your new tabs with gr.Tab.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_ui_train_tabs'], callback, name=name, category='ui_train_tabs')\r\n\r\n\r\ndef on_ui_settings(callback, *, name=None):\r\n    \"\"\"register a function to be called before UI settings are populated; add your settings\r\n    by using shared.opts.add_option(shared.OptionInfo(...)) \"\"\"\r\n    add_callback(callback_map['callbacks_ui_settings'], callback, name=name, category='ui_settings')\r\n\r\n\r\ndef on_before_image_saved(callback, *, name=None):\r\n    \"\"\"register a function to be called before an image is saved to a file.\r\n    The callback is called with one argument:\r\n        - params: ImageSaveParams - parameters the image is to be saved with. 
You can change fields in this object.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_before_image_saved'], callback, name=name, category='before_image_saved')\r\n\r\n\r\ndef on_image_saved(callback, *, name=None):\r\n    \"\"\"register a function to be called after an image is saved to a file.\r\n    The callback is called with one argument:\r\n        - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_image_saved'], callback, name=name, category='image_saved')\r\n\r\n\r\ndef on_extra_noise(callback, *, name=None):\r\n    \"\"\"register a function to be called before adding extra noise in img2img or hires fix;\r\n    The callback is called with one argument:\r\n        - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_extra_noise'], callback, name=name, category='extra_noise')\r\n\r\n\r\ndef on_cfg_denoiser(callback, *, name=None):\r\n    \"\"\"register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.\r\n    The callback is called with one argument:\r\n        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_cfg_denoiser'], callback, name=name, category='cfg_denoiser')\r\n\r\n\r\ndef on_cfg_denoised(callback, *, name=None):\r\n    \"\"\"register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.\r\n    The callback is called with one argument:\r\n        - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_cfg_denoised'], callback, name=name, category='cfg_denoised')\r\n\r\n\r\ndef on_cfg_after_cfg(callback, *, 
name=None):\r\n    \"\"\"register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.\r\n    The callback is called with one argument:\r\n        - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_cfg_after_cfg'], callback, name=name, category='cfg_after_cfg')\r\n\r\n\r\ndef on_before_component(callback, *, name=None):\r\n    \"\"\"register a function to be called before a component is created.\r\n    The callback is called with arguments:\r\n        - component - gradio component that is about to be created.\r\n        - **kwargs - args to gradio.components.IOComponent.__init__ function\r\n\r\n    Use elem_id/label fields of kwargs to figure out which component it is.\r\n    This can be useful to inject your own components somewhere in the middle of vanilla UI.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_before_component'], callback, name=name, category='before_component')\r\n\r\n\r\ndef on_after_component(callback, *, name=None):\r\n    \"\"\"register a function to be called after a component is created. See on_before_component for more.\"\"\"\r\n    add_callback(callback_map['callbacks_after_component'], callback, name=name, category='after_component')\r\n\r\n\r\ndef on_image_grid(callback, *, name=None):\r\n    \"\"\"register a function to be called before making an image grid.\r\n    The callback is called with one argument:\r\n       - params: ImageGridLoopParams - parameters to be used for grid creation. 
Can be modified.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_image_grid'], callback, name=name, category='image_grid')\r\n\r\n\r\ndef on_infotext_pasted(callback, *, name=None):\r\n    \"\"\"register a function to be called before applying an infotext.\r\n    The callback is called with two arguments:\r\n       - infotext: str - raw infotext.\r\n       - result: dict[str, any] - parsed infotext parameters.\r\n    \"\"\"\r\n    add_callback(callback_map['callbacks_infotext_pasted'], callback, name=name, category='infotext_pasted')\r\n\r\n\r\ndef on_script_unloaded(callback, *, name=None):\r\n    \"\"\"register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that\r\n    the script did should be reverted here\"\"\"\r\n\r\n    add_callback(callback_map['callbacks_script_unloaded'], callback, name=name, category='script_unloaded')\r\n\r\n\r\ndef on_before_ui(callback, *, name=None):\r\n    \"\"\"register a function to be called before the UI is created.\"\"\"\r\n\r\n    add_callback(callback_map['callbacks_before_ui'], callback, name=name, category='before_ui')\r\n\r\n\r\ndef on_list_optimizers(callback, *, name=None):\r\n    \"\"\"register a function to be called when UI is making a list of cross attention optimization options.\r\n    The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization\r\n    to it.\"\"\"\r\n\r\n    add_callback(callback_map['callbacks_list_optimizers'], callback, name=name, category='list_optimizers')\r\n\r\n\r\ndef on_list_unets(callback, *, name=None):\r\n    \"\"\"register a function to be called when UI is making a list of alternative options for unet.\r\n    The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.\"\"\"\r\n\r\n    add_callback(callback_map['callbacks_list_unets'], callback, name=name, category='list_unets')\r\n\r\n\r\ndef 
on_before_token_counter(callback, *, name=None):\r\n    \"\"\"register a function to be called when UI is counting tokens for a prompt.\r\n    The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.\"\"\"\r\n\r\n    add_callback(callback_map['callbacks_before_token_counter'], callback, name=name, category='before_token_counter')\r\n"
  },
  {
    "path": "modules/script_loading.py",
    "content": "import os\r\nimport importlib.util\r\n\r\nfrom modules import errors\r\n\r\n\r\nloaded_scripts = {}\r\n\r\n\r\ndef load_module(path):\r\n    module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path)\r\n    module = importlib.util.module_from_spec(module_spec)\r\n    module_spec.loader.exec_module(module)\r\n\r\n    loaded_scripts[path] = module\r\n    return module\r\n\r\n\r\ndef preload_extensions(extensions_dir, parser, extension_list=None):\r\n    if not os.path.isdir(extensions_dir):\r\n        return\r\n\r\n    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)\r\n    for dirname in sorted(extensions):\r\n        preload_script = os.path.join(extensions_dir, dirname, \"preload.py\")\r\n        if not os.path.isfile(preload_script):\r\n            continue\r\n\r\n        try:\r\n            module = load_module(preload_script)\r\n            if hasattr(module, 'preload'):\r\n                module.preload(parser)\r\n\r\n        except Exception:\r\n            errors.report(f\"Error running preload() for {preload_script}\", exc_info=True)\r\n"
  },
  {
    "path": "modules/scripts.py",
    "content": "import os\r\nimport re\r\nimport sys\r\nimport inspect\r\nfrom collections import namedtuple\r\nfrom dataclasses import dataclass\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer, util\r\n\r\ntopological_sort = util.topological_sort\r\n\r\nAlwaysVisible = object()\r\n\r\nclass MaskBlendArgs:\r\n    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):\r\n        self.current_latent = current_latent\r\n        self.nmask = nmask\r\n        self.init_latent = init_latent\r\n        self.mask = mask\r\n        self.blended_latent = blended_latent\r\n\r\n        self.denoiser = denoiser\r\n        self.is_final_blend = denoiser is None\r\n        self.sigma = sigma\r\n\r\nclass PostSampleArgs:\r\n    def __init__(self, samples):\r\n        self.samples = samples\r\n\r\nclass PostprocessImageArgs:\r\n    def __init__(self, image):\r\n        self.image = image\r\n\r\nclass PostProcessMaskOverlayArgs:\r\n    def __init__(self, index, mask_for_overlay, overlay_image):\r\n        self.index = index\r\n        self.mask_for_overlay = mask_for_overlay\r\n        self.overlay_image = overlay_image\r\n\r\nclass PostprocessBatchListArgs:\r\n    def __init__(self, images):\r\n        self.images = images\r\n\r\n\r\n@dataclass\r\nclass OnComponent:\r\n    component: gr.blocks.Block\r\n\r\n\r\nclass Script:\r\n    name = None\r\n    \"\"\"script's internal name derived from title\"\"\"\r\n\r\n    section = None\r\n    \"\"\"name of UI section that the script's controls will be placed into\"\"\"\r\n\r\n    filename = None\r\n    args_from = None\r\n    args_to = None\r\n    alwayson = False\r\n\r\n    is_txt2img = False\r\n    is_img2img = False\r\n    tabname = None\r\n\r\n    group = None\r\n    \"\"\"A gr.Group component that has all script's UI inside it.\"\"\"\r\n\r\n    create_group = True\r\n    \"\"\"If 
False, for alwayson scripts, a group component will not be created.\"\"\"\r\n\r\n    infotext_fields = None\r\n    \"\"\"if set in ui(), this is a list of pairs of gradio component + text; the text will be used when\r\n    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example\r\n    \"\"\"\r\n\r\n    paste_field_names = None\r\n    \"\"\"if set in ui(), this is a list of names of infotext fields; the fields will be sent through the\r\n    various \"Send to <X>\" buttons when clicked\r\n    \"\"\"\r\n\r\n    api_info = None\r\n    \"\"\"Generated value of type modules.api.models.ScriptInfo with information about the script for API\"\"\"\r\n\r\n    on_before_component_elem_id = None\r\n    \"\"\"list of callbacks to be called before a component with an elem_id is created\"\"\"\r\n\r\n    on_after_component_elem_id = None\r\n    \"\"\"list of callbacks to be called after a component with an elem_id is created\"\"\"\r\n\r\n    setup_for_ui_only = False\r\n    \"\"\"If true, the script setup will only be run in Gradio UI, not in API\"\"\"\r\n\r\n    controls = None\r\n    \"\"\"A list of controls returned by the ui().\"\"\"\r\n\r\n    def title(self):\r\n        \"\"\"this function should return the title of the script. This is what will be displayed in the dropdown menu.\"\"\"\r\n\r\n        raise NotImplementedError()\r\n\r\n    def ui(self, is_img2img):\r\n        \"\"\"this function should create gradio UI elements. 
See https://gradio.app/docs/#components\r\n        The return value should be an array of all components that are used in processing.\r\n        Values of those returned components will be passed to run() and process() functions.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def show(self, is_img2img):\r\n        \"\"\"\r\n        is_img2img is True if this function is called for the img2img interface, and False otherwise\r\n\r\n        This function should return:\r\n         - False if the script should not be shown in UI at all\r\n         - True if the script should be shown in UI if it's selected in the scripts dropdown\r\n         - script.AlwaysVisible if the script should be shown in UI at all times\r\n         \"\"\"\r\n\r\n        return True\r\n\r\n    def run(self, p, *args):\r\n        \"\"\"\r\n        This function is called if the script has been selected in the script dropdown.\r\n        It must do all processing and return the Processed object with results, same as\r\n        one returned by processing.process_images.\r\n\r\n        Usually the processing is done by calling the processing.process_images function.\r\n\r\n        args contains all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def setup(self, p, *args):\r\n        \"\"\"For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts.\r\n        args contains all values returned by components from ui().\r\n        \"\"\"\r\n        pass\r\n\r\n    def before_process(self, p, *args):\r\n        \"\"\"\r\n        This function is called very early during processing begins for AlwaysVisible scripts.\r\n        You can modify the processing object (p) here, inject hooks, etc.\r\n        args contains all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def process(self, p, *args):\r\n        \"\"\"\r\n        This function is called before processing 
begins for AlwaysVisible scripts.\r\n        You can modify the processing object (p) here, inject hooks, etc.\r\n        args contains all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def before_process_batch(self, p, *args, **kwargs):\r\n        \"\"\"\r\n        Called before extra networks are parsed from the prompt, so you can add\r\n        new extra network keywords to the prompt with this callback.\r\n\r\n        **kwargs will have those items:\r\n          - batch_number - index of current batch, from 0 to number of batches-1\r\n          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things\r\n          - seeds - list of seeds for current batch\r\n          - subseeds - list of subseeds for current batch\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def after_extra_networks_activate(self, p, *args, **kwargs):\r\n        \"\"\"\r\n        Called after extra networks activation, before conds calculation\r\n        allow modification of the network after extra networks activation been applied\r\n        won't be call if p.disable_extra_networks\r\n\r\n        **kwargs will have those items:\r\n          - batch_number - index of current batch, from 0 to number of batches-1\r\n          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things\r\n          - seeds - list of seeds for current batch\r\n          - subseeds - list of subseeds for current batch\r\n          - extra_network_data - list of ExtraNetworkParams for current stage\r\n        \"\"\"\r\n        pass\r\n\r\n    def process_before_every_sampling(self, p, *args, **kwargs):\r\n        \"\"\"\r\n        Similar to process(), called before every sampling.\r\n        If you use high-res fix, this will be called two times.\r\n        \"\"\"\r\n        pass\r\n\r\n    def 
process_batch(self, p, *args, **kwargs):\r\n        \"\"\"\r\n        Same as process(), but called for every batch.\r\n\r\n        **kwargs will have those items:\r\n          - batch_number - index of current batch, from 0 to number of batches-1\r\n          - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things\r\n          - seeds - list of seeds for current batch\r\n          - subseeds - list of subseeds for current batch\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess_batch(self, p, *args, **kwargs):\r\n        \"\"\"\r\n        Same as process_batch(), but called for every batch after it has been generated.\r\n\r\n        **kwargs will have same items as process_batch, and also:\r\n          - batch_number - index of current batch, from 0 to number of batches-1\r\n          - images - torch tensor with all generated images, with values ranging from 0 to 1;\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):\r\n        \"\"\"\r\n        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.\r\n        This is useful when you want to update the entire batch instead of individual images.\r\n\r\n        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.\r\n        If the number of images is different from the batch size when returning,\r\n        then the script has the responsibility to also update the following attributes in the processing object (p):\r\n          - p.prompts\r\n          - p.negative_prompts\r\n          - p.seeds\r\n          - p.subseeds\r\n\r\n        **kwargs will have same items as process_batch, and also:\r\n          - batch_number - index of current batch, from 0 to number of batches-1\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def 
on_mask_blend(self, p, mba: MaskBlendArgs, *args):\r\n        \"\"\"\r\n        Called in inpainting mode when the original content is blended with the inpainted content.\r\n        This is called at every step in the denoising process and once at the end.\r\n        If is_final_blend is true, this is called for the final blending stage.\r\n        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def post_sample(self, p, ps: PostSampleArgs, *args):\r\n        \"\"\"\r\n        Called after the samples have been generated,\r\n        but before they have been decoded by the VAE, if applicable.\r\n        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):\r\n        \"\"\"\r\n        Called for every image after it has been generated.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):\r\n        \"\"\"\r\n        Called for every image after it has been generated.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs, *args):\r\n        \"\"\"\r\n        Called for every image after it has been generated.\r\n        Same as postprocess_image but after inpaint_full_res composite\r\n        So that it operates on the full image instead of the inpaint_full_res crop region.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def postprocess(self, p, processed, *args):\r\n        \"\"\"\r\n        This function is called after processing ends for AlwaysVisible scripts.\r\n        args contains all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def before_component(self, component, **kwargs):\r\n        \"\"\"\r\n        Called before a component is created.\r\n       
 Use elem_id/label fields of kwargs to figure out which component it is.\r\n        This can be useful to inject your own components somewhere in the middle of vanilla UI.\r\n        You can return created components in the ui() function to add them to the list of arguments for your processing functions\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def after_component(self, component, **kwargs):\r\n        \"\"\"\r\n        Called after a component is created. Same as above.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def on_before_component(self, callback, *, elem_id):\r\n        \"\"\"\r\n        Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.\r\n\r\n        May be called in show() or ui() - but it may be too late in latter as some components may already be created.\r\n\r\n        This function is an alternative to before_component in that it also cllows to run before a component is created, but\r\n        it doesn't require to be called for every created component - just for the one you need.\r\n        \"\"\"\r\n        if self.on_before_component_elem_id is None:\r\n            self.on_before_component_elem_id = []\r\n\r\n        self.on_before_component_elem_id.append((elem_id, callback))\r\n\r\n    def on_after_component(self, callback, *, elem_id):\r\n        \"\"\"\r\n        Calls callback after a component is created. 
The callback function is called with a single argument of type OnComponent.\r\n        \"\"\"\r\n        if self.on_after_component_elem_id is None:\r\n            self.on_after_component_elem_id = []\r\n\r\n        self.on_after_component_elem_id.append((elem_id, callback))\r\n\r\n    def describe(self):\r\n        \"\"\"unused\"\"\"\r\n        return \"\"\r\n\r\n    def elem_id(self, item_id):\r\n        \"\"\"helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id\"\"\"\r\n\r\n        need_tabname = self.show(True) == self.show(False)\r\n        tabkind = 'img2img' if self.is_img2img else 'txt2img'\r\n        tabname = f\"{tabkind}_\" if need_tabname else \"\"\r\n        title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\\s', '_', self.title().lower()))\r\n\r\n        return f'script_{tabname}{title}_{item_id}'\r\n\r\n    def before_hr(self, p, *args):\r\n        \"\"\"\r\n        This function is called before hires fix start.\r\n        \"\"\"\r\n        pass\r\n\r\n\r\nclass ScriptBuiltinUI(Script):\r\n    setup_for_ui_only = True\r\n\r\n    def elem_id(self, item_id):\r\n        \"\"\"helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id\"\"\"\r\n\r\n        need_tabname = self.show(True) == self.show(False)\r\n        tabname = ('img2img' if self.is_img2img else 'txt2img') + \"_\" if need_tabname else \"\"\r\n\r\n        return f'{tabname}{item_id}'\r\n\r\n    def show(self, is_img2img):\r\n        return AlwaysVisible\r\n\r\n\r\ncurrent_basedir = paths.script_path\r\n\r\n\r\ndef basedir():\r\n    \"\"\"returns the base directory for the current script. 
For scripts in the main scripts directory,\r\n    this is the main directory (where webui.py resides), and for scripts in extensions directory\r\n    (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)\r\n    \"\"\"\r\n    return current_basedir\r\n\r\n\r\nScriptFile = namedtuple(\"ScriptFile\", [\"basedir\", \"filename\", \"path\"])\r\n\r\nscripts_data = []\r\npostprocessing_scripts_data = []\r\nScriptClassData = namedtuple(\"ScriptClassData\", [\"script_class\", \"path\", \"basedir\", \"module\"])\r\n\r\n\r\n@dataclass\r\nclass ScriptWithDependencies:\r\n    script_canonical_name: str\r\n    file: ScriptFile\r\n    requires: list\r\n    load_before: list\r\n    load_after: list\r\n\r\n\r\ndef list_scripts(scriptdirname, extension, *, include_extensions=True):\r\n    scripts = {}\r\n\r\n    loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}\r\n    loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}\r\n\r\n    # build script dependency map\r\n    root_script_basedir = os.path.join(paths.script_path, scriptdirname)\r\n    if os.path.exists(root_script_basedir):\r\n        for filename in sorted(os.listdir(root_script_basedir)):\r\n            if not os.path.isfile(os.path.join(root_script_basedir, filename)):\r\n                continue\r\n\r\n            if os.path.splitext(filename)[1].lower() != extension:\r\n                continue\r\n\r\n            script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))\r\n            scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])\r\n\r\n    if include_extensions:\r\n        for ext in extensions.active():\r\n            extension_scripts_list = ext.list_files(scriptdirname, extension)\r\n            for extension_script in extension_scripts_list:\r\n                if not os.path.isfile(extension_script.path):\r\n                    
continue\r\n\r\n                script_canonical_name = (\"builtin/\" if ext.is_builtin else \"\") + ext.canonical_name + \"/\" + extension_script.filename\r\n                relative_path = scriptdirname + \"/\" + extension_script.filename\r\n\r\n                script = ScriptWithDependencies(\r\n                    script_canonical_name=script_canonical_name,\r\n                    file=extension_script,\r\n                    requires=ext.metadata.get_script_requirements(\"Requires\", relative_path, scriptdirname),\r\n                    load_before=ext.metadata.get_script_requirements(\"Before\", relative_path, scriptdirname),\r\n                    load_after=ext.metadata.get_script_requirements(\"After\", relative_path, scriptdirname),\r\n                )\r\n\r\n                scripts[script_canonical_name] = script\r\n                loaded_extensions_scripts[ext.canonical_name].append(script)\r\n\r\n    for script_canonical_name, script in scripts.items():\r\n        # load before requires inverse dependency\r\n        # in this case, append the script name into the load_after list of the specified script\r\n        for load_before in script.load_before:\r\n            # if this requires an individual script to be loaded before\r\n            other_script = scripts.get(load_before)\r\n            if other_script:\r\n                other_script.load_after.append(script_canonical_name)\r\n\r\n            # if this requires an extension\r\n            other_extension_scripts = loaded_extensions_scripts.get(load_before)\r\n            if other_extension_scripts:\r\n                for other_script in other_extension_scripts:\r\n                    other_script.load_after.append(script_canonical_name)\r\n\r\n        # if After mentions an extension, remove it and instead add all of its scripts\r\n        for load_after in list(script.load_after):\r\n            if load_after not in scripts and load_after in loaded_extensions_scripts:\r\n                
script.load_after.remove(load_after)\r\n\r\n                for other_script in loaded_extensions_scripts.get(load_after, []):\r\n                    script.load_after.append(other_script.script_canonical_name)\r\n\r\n    dependencies = {}\r\n\r\n    for script_canonical_name, script in scripts.items():\r\n        for required_script in script.requires:\r\n            if required_script not in scripts and required_script not in loaded_extensions:\r\n                errors.report(f'Script \"{script_canonical_name}\" requires \"{required_script}\" to be loaded, but it is not.', exc_info=False)\r\n\r\n        dependencies[script_canonical_name] = script.load_after\r\n\r\n    ordered_scripts = topological_sort(dependencies)\r\n    scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]\r\n\r\n    return scripts_list\r\n\r\n\r\ndef list_files_with_name(filename):\r\n    res = []\r\n\r\n    dirs = [paths.script_path] + [ext.path for ext in extensions.active()]\r\n\r\n    for dirpath in dirs:\r\n        if not os.path.isdir(dirpath):\r\n            continue\r\n\r\n        path = os.path.join(dirpath, filename)\r\n        if os.path.isfile(path):\r\n            res.append(path)\r\n\r\n    return res\r\n\r\n\r\ndef load_scripts():\r\n    global current_basedir\r\n    scripts_data.clear()\r\n    postprocessing_scripts_data.clear()\r\n    script_callbacks.clear_callbacks()\r\n\r\n    scripts_list = list_scripts(\"scripts\", \".py\") + list_scripts(\"modules/processing_scripts\", \".py\", include_extensions=False)\r\n\r\n    syspath = sys.path\r\n\r\n    def register_scripts_from_module(module):\r\n        for script_class in module.__dict__.values():\r\n            if not inspect.isclass(script_class):\r\n                continue\r\n\r\n            if issubclass(script_class, Script):\r\n                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))\r\n            elif 
issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):\r\n                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))\r\n\r\n    # here the scripts_list is already ordered\r\n    # processing_script is not considered though\r\n    for scriptfile in scripts_list:\r\n        try:\r\n            if scriptfile.basedir != paths.script_path:\r\n                sys.path = [scriptfile.basedir] + sys.path\r\n            current_basedir = scriptfile.basedir\r\n\r\n            script_module = script_loading.load_module(scriptfile.path)\r\n            register_scripts_from_module(script_module)\r\n\r\n        except Exception:\r\n            errors.report(f\"Error loading script: {scriptfile.filename}\", exc_info=True)\r\n\r\n        finally:\r\n            sys.path = syspath\r\n            current_basedir = paths.script_path\r\n            timer.startup_timer.record(scriptfile.filename)\r\n\r\n    global scripts_txt2img, scripts_img2img, scripts_postproc\r\n\r\n    scripts_txt2img = ScriptRunner()\r\n    scripts_img2img = ScriptRunner()\r\n    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()\r\n\r\n\r\ndef wrap_call(func, filename, funcname, *args, default=None, **kwargs):\r\n    try:\r\n        return func(*args, **kwargs)\r\n    except Exception:\r\n        errors.report(f\"Error calling: {filename}/{funcname}\", exc_info=True)\r\n\r\n    return default\r\n\r\n\r\nclass ScriptRunner:\r\n    def __init__(self):\r\n        self.scripts = []\r\n        self.selectable_scripts = []\r\n        self.alwayson_scripts = []\r\n        self.titles = []\r\n        self.title_map = {}\r\n        self.infotext_fields = []\r\n        self.paste_field_names = []\r\n        self.inputs = [None]\r\n\r\n        self.callback_map = {}\r\n        self.callback_names = [\r\n            'before_process',\r\n            'process',\r\n            'before_process_batch',\r\n            
'after_extra_networks_activate',\r\n            'process_batch',\r\n            'postprocess',\r\n            'postprocess_batch',\r\n            'postprocess_batch_list',\r\n            'post_sample',\r\n            'on_mask_blend',\r\n            'postprocess_image',\r\n            'postprocess_maskoverlay',\r\n            'postprocess_image_after_composite',\r\n            'before_component',\r\n            'after_component',\r\n        ]\r\n\r\n        self.on_before_component_elem_id = {}\r\n        \"\"\"dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks\"\"\"\r\n\r\n        self.on_after_component_elem_id = {}\r\n        \"\"\"dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks\"\"\"\r\n\r\n    def initialize_scripts(self, is_img2img):\r\n        from modules import scripts_auto_postprocessing\r\n\r\n        self.scripts.clear()\r\n        self.alwayson_scripts.clear()\r\n        self.selectable_scripts.clear()\r\n\r\n        auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()\r\n\r\n        for script_data in auto_processing_scripts + scripts_data:\r\n            try:\r\n                script = script_data.script_class()\r\n            except Exception:\r\n                errors.report(f\"Error # failed to initialize Script {script_data.module}: \", exc_info=True)\r\n                continue\r\n\r\n            script.filename = script_data.path\r\n            script.is_txt2img = not is_img2img\r\n            script.is_img2img = is_img2img\r\n            script.tabname = \"img2img\" if is_img2img else \"txt2img\"\r\n\r\n            visibility = script.show(script.is_img2img)\r\n\r\n            if visibility == AlwaysVisible:\r\n                self.scripts.append(script)\r\n                self.alwayson_scripts.append(script)\r\n                script.alwayson = True\r\n\r\n            elif visibility:\r\n                
self.scripts.append(script)\r\n                self.selectable_scripts.append(script)\r\n\r\n        self.callback_map.clear()\r\n\r\n        self.apply_on_before_component_callbacks()\r\n\r\n    def apply_on_before_component_callbacks(self):\r\n        for script in self.scripts:\r\n            on_before = script.on_before_component_elem_id or []\r\n            on_after = script.on_after_component_elem_id or []\r\n\r\n            for elem_id, callback in on_before:\r\n                if elem_id not in self.on_before_component_elem_id:\r\n                    self.on_before_component_elem_id[elem_id] = []\r\n\r\n                self.on_before_component_elem_id[elem_id].append((callback, script))\r\n\r\n            for elem_id, callback in on_after:\r\n                if elem_id not in self.on_after_component_elem_id:\r\n                    self.on_after_component_elem_id[elem_id] = []\r\n\r\n                self.on_after_component_elem_id[elem_id].append((callback, script))\r\n\r\n            on_before.clear()\r\n            on_after.clear()\r\n\r\n    def create_script_ui(self, script):\r\n\r\n        script.args_from = len(self.inputs)\r\n        script.args_to = len(self.inputs)\r\n\r\n        try:\r\n            self.create_script_ui_inner(script)\r\n        except Exception:\r\n            errors.report(f\"Error creating UI for {script.name}: \", exc_info=True)\r\n\r\n    def create_script_ui_inner(self, script):\r\n        import modules.api.models as api_models\r\n\r\n        controls = wrap_call(script.ui, script.filename, \"ui\", script.is_img2img)\r\n        script.controls = controls\r\n\r\n        if controls is None:\r\n            return\r\n\r\n        script.name = wrap_call(script.title, script.filename, \"title\", default=script.filename).lower()\r\n\r\n        api_args = []\r\n\r\n        for control in controls:\r\n            control.custom_script_source = os.path.basename(script.filename)\r\n\r\n            arg_info = 
api_models.ScriptArg(label=control.label or \"\")\r\n\r\n            for field in (\"value\", \"minimum\", \"maximum\", \"step\"):\r\n                v = getattr(control, field, None)\r\n                if v is not None:\r\n                    setattr(arg_info, field, v)\r\n\r\n            choices = getattr(control, 'choices', None)  # as of gradio 3.41, some items in choices are strings, and some are tuples where the first elem is the string\r\n            if choices is not None:\r\n                arg_info.choices = [x[0] if isinstance(x, tuple) else x for x in choices]\r\n\r\n            api_args.append(arg_info)\r\n\r\n        script.api_info = api_models.ScriptInfo(\r\n            name=script.name,\r\n            is_img2img=script.is_img2img,\r\n            is_alwayson=script.alwayson,\r\n            args=api_args,\r\n        )\r\n\r\n        if script.infotext_fields is not None:\r\n            self.infotext_fields += script.infotext_fields\r\n\r\n        if script.paste_field_names is not None:\r\n            self.paste_field_names += script.paste_field_names\r\n\r\n        self.inputs += controls\r\n        script.args_to = len(self.inputs)\r\n\r\n    def setup_ui_for_section(self, section, scriptlist=None):\r\n        if scriptlist is None:\r\n            scriptlist = self.alwayson_scripts\r\n\r\n        for script in scriptlist:\r\n            if script.alwayson and script.section != section:\r\n                continue\r\n\r\n            if script.create_group:\r\n                with gr.Group(visible=script.alwayson) as group:\r\n                    self.create_script_ui(script)\r\n\r\n                script.group = group\r\n            else:\r\n                self.create_script_ui(script)\r\n\r\n    def prepare_ui(self):\r\n        self.inputs = [None]\r\n\r\n    def setup_ui(self):\r\n        all_titles = [wrap_call(script.title, script.filename, \"title\") or script.filename for script in self.scripts]\r\n        self.title_map = {title.lower(): 
script for title, script in zip(all_titles, self.scripts)}\r\n        self.titles = [wrap_call(script.title, script.filename, \"title\") or f\"{script.filename} [error]\" for script in self.selectable_scripts]\r\n\r\n        self.setup_ui_for_section(None)\r\n\r\n        dropdown = gr.Dropdown(label=\"Script\", elem_id=\"script_list\", choices=[\"None\"] + self.titles, value=\"None\", type=\"index\")\r\n        self.inputs[0] = dropdown\r\n\r\n        self.setup_ui_for_section(None, self.selectable_scripts)\r\n\r\n        def select_script(script_index):\r\n            if script_index is None:\r\n                script_index = 0\r\n            selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None\r\n\r\n            return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]\r\n\r\n        def init_field(title):\r\n            \"\"\"called when an initial value is set from ui-config.json to show script's UI components\"\"\"\r\n\r\n            if title == 'None':\r\n                return\r\n\r\n            script_index = self.titles.index(title)\r\n            self.selectable_scripts[script_index].group.visible = True\r\n\r\n        dropdown.init_field = init_field\r\n\r\n        dropdown.change(\r\n            fn=select_script,\r\n            inputs=[dropdown],\r\n            outputs=[script.group for script in self.selectable_scripts]\r\n        )\r\n\r\n        self.script_load_ctr = 0\r\n\r\n        def onload_script_visibility(params):\r\n            title = params.get('Script', None)\r\n            if title:\r\n                try:\r\n                    title_index = self.titles.index(title)\r\n                    visibility = title_index == self.script_load_ctr\r\n                    self.script_load_ctr = (self.script_load_ctr + 1) % len(self.titles)\r\n                    return gr.update(visible=visibility)\r\n                except ValueError:\r\n                    params['Script'] = None\r\n       
             massage = f'Cannot find Script: \"{title}\"'\r\n                    print(massage)\r\n                    gr.Warning(massage)\r\n            return gr.update(visible=False)\r\n\r\n        self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))\r\n        self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])\r\n\r\n        self.apply_on_before_component_callbacks()\r\n\r\n        return self.inputs\r\n\r\n    def run(self, p, *args):\r\n        script_index = args[0]\r\n\r\n        if script_index == 0 or script_index is None:\r\n            return None\r\n\r\n        script = self.selectable_scripts[script_index-1]\r\n\r\n        if script is None:\r\n            return None\r\n\r\n        script_args = args[script.args_from:script.args_to]\r\n        processed = script.run(p, *script_args)\r\n\r\n        shared.total_tqdm.clear()\r\n\r\n        return processed\r\n\r\n    def list_scripts_for_method(self, method_name):\r\n        if method_name in ('before_component', 'after_component'):\r\n            return self.scripts\r\n        else:\r\n            return self.alwayson_scripts\r\n\r\n    def create_ordered_callbacks_list(self,  method_name, *, enable_user_sort=True):\r\n        script_list = self.list_scripts_for_method(method_name)\r\n        category = f'script_{method_name}'\r\n        callbacks = []\r\n\r\n        for script in script_list:\r\n            if getattr(script.__class__, method_name, None) == getattr(Script, method_name, None):\r\n                continue\r\n\r\n            script_callbacks.add_callback(callbacks, script, category=category, name=script.__class__.__name__, filename=script.filename)\r\n\r\n        return script_callbacks.sort_callbacks(category, callbacks, enable_user_sort=enable_user_sort)\r\n\r\n    def ordered_callbacks(self, method_name, *, enable_user_sort=True):\r\n        script_list = 
self.list_scripts_for_method(method_name)\r\n        category = f'script_{method_name}'\r\n\r\n        scrpts_len, callbacks = self.callback_map.get(category, (-1, None))\r\n\r\n        if callbacks is None or scrpts_len != len(script_list):\r\n            callbacks = self.create_ordered_callbacks_list(method_name, enable_user_sort=enable_user_sort)\r\n            self.callback_map[category] = len(script_list), callbacks\r\n\r\n        return callbacks\r\n\r\n    def ordered_scripts(self, method_name):\r\n        return [x.callback for x in self.ordered_callbacks(method_name)]\r\n\r\n    def before_process(self, p):\r\n        for script in self.ordered_scripts('before_process'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.before_process(p, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running before_process: {script.filename}\", exc_info=True)\r\n\r\n    def process(self, p):\r\n        for script in self.ordered_scripts('process'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.process(p, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running process: {script.filename}\", exc_info=True)\r\n\r\n    def process_before_every_sampling(self, p, **kwargs):\r\n        for script in self.ordered_scripts('process_before_every_sampling'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.process_before_every_sampling(p, *script_args, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running process_before_every_sampling: {script.filename}\", exc_info=True)\r\n\r\n    def before_process_batch(self, p, **kwargs):\r\n        for script in self.ordered_scripts('before_process_batch'):\r\n            try:\r\n                
script_args = p.script_args[script.args_from:script.args_to]\r\n                script.before_process_batch(p, *script_args, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running before_process_batch: {script.filename}\", exc_info=True)\r\n\r\n    def after_extra_networks_activate(self, p, **kwargs):\r\n        for script in self.ordered_scripts('after_extra_networks_activate'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.after_extra_networks_activate(p, *script_args, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running after_extra_networks_activate: {script.filename}\", exc_info=True)\r\n\r\n    def process_batch(self, p, **kwargs):\r\n        for script in self.ordered_scripts('process_batch'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.process_batch(p, *script_args, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running process_batch: {script.filename}\", exc_info=True)\r\n\r\n    def postprocess(self, p, processed):\r\n        for script in self.ordered_scripts('postprocess'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.postprocess(p, processed, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess: {script.filename}\", exc_info=True)\r\n\r\n    def postprocess_batch(self, p, images, **kwargs):\r\n        for script in self.ordered_scripts('postprocess_batch'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.postprocess_batch(p, *script_args, images=images, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess_batch: 
{script.filename}\", exc_info=True)\r\n\r\n    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):\r\n        for script in self.ordered_scripts('postprocess_batch_list'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.postprocess_batch_list(p, pp, *script_args, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess_batch_list: {script.filename}\", exc_info=True)\r\n\r\n    def post_sample(self, p, ps: PostSampleArgs):\r\n        for script in self.ordered_scripts('post_sample'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.post_sample(p, ps, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running post_sample: {script.filename}\", exc_info=True)\r\n\r\n    def on_mask_blend(self, p, mba: MaskBlendArgs):\r\n        for script in self.ordered_scripts('on_mask_blend'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.on_mask_blend(p, mba, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running post_sample: {script.filename}\", exc_info=True)\r\n\r\n    def postprocess_image(self, p, pp: PostprocessImageArgs):\r\n        for script in self.ordered_scripts('postprocess_image'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.postprocess_image(p, pp, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess_image: {script.filename}\", exc_info=True)\r\n\r\n    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):\r\n        for script in self.ordered_scripts('postprocess_maskoverlay'):\r\n            try:\r\n                script_args = 
p.script_args[script.args_from:script.args_to]\r\n                script.postprocess_maskoverlay(p, ppmo, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess_image: {script.filename}\", exc_info=True)\r\n\r\n    def postprocess_image_after_composite(self, p, pp: PostprocessImageArgs):\r\n        for script in self.ordered_scripts('postprocess_image_after_composite'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.postprocess_image_after_composite(p, pp, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running postprocess_image_after_composite: {script.filename}\", exc_info=True)\r\n\r\n    def before_component(self, component, **kwargs):\r\n        for callback, script in self.on_before_component_elem_id.get(kwargs.get(\"elem_id\"), []):\r\n            try:\r\n                callback(OnComponent(component=component))\r\n            except Exception:\r\n                errors.report(f\"Error running on_before_component: {script.filename}\", exc_info=True)\r\n\r\n        for script in self.ordered_scripts('before_component'):\r\n            try:\r\n                script.before_component(component, **kwargs)\r\n            except Exception:\r\n                errors.report(f\"Error running before_component: {script.filename}\", exc_info=True)\r\n\r\n    def after_component(self, component, **kwargs):\r\n        for callback, script in self.on_after_component_elem_id.get(component.elem_id, []):\r\n            try:\r\n                callback(OnComponent(component=component))\r\n            except Exception:\r\n                errors.report(f\"Error running on_after_component: {script.filename}\", exc_info=True)\r\n\r\n        for script in self.ordered_scripts('after_component'):\r\n            try:\r\n                script.after_component(component, **kwargs)\r\n            except 
Exception:\r\n                errors.report(f\"Error running after_component: {script.filename}\", exc_info=True)\r\n\r\n    def script(self, title):\r\n        return self.title_map.get(title.lower())\r\n\r\n    def reload_sources(self, cache):\r\n        for si, script in list(enumerate(self.scripts)):\r\n            args_from = script.args_from\r\n            args_to = script.args_to\r\n            filename = script.filename\r\n\r\n            module = cache.get(filename, None)\r\n            if module is None:\r\n                module = script_loading.load_module(script.filename)\r\n                cache[filename] = module\r\n\r\n            for script_class in module.__dict__.values():\r\n                if type(script_class) == type and issubclass(script_class, Script):\r\n                    self.scripts[si] = script_class()\r\n                    self.scripts[si].filename = filename\r\n                    self.scripts[si].args_from = args_from\r\n                    self.scripts[si].args_to = args_to\r\n\r\n    def before_hr(self, p):\r\n        for script in self.ordered_scripts('before_hr'):\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.before_hr(p, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running before_hr: {script.filename}\", exc_info=True)\r\n\r\n    def setup_scrips(self, p, *, is_ui=True):\r\n        for script in self.ordered_scripts('setup'):\r\n            if not is_ui and script.setup_for_ui_only:\r\n                continue\r\n\r\n            try:\r\n                script_args = p.script_args[script.args_from:script.args_to]\r\n                script.setup(p, *script_args)\r\n            except Exception:\r\n                errors.report(f\"Error running setup: {script.filename}\", exc_info=True)\r\n\r\n    def set_named_arg(self, args, script_name, arg_elem_id, value, fuzzy=False):\r\n        \"\"\"Locate an arg of 
a specific script in script_args and set its value\r\n        Args:\r\n            args: all script args of process p, p.script_args\r\n            script_name: the name target script name to\r\n            arg_elem_id: the elem_id of the target arg\r\n            value: the value to set\r\n            fuzzy: if True, arg_elem_id can be a substring of the control.elem_id else exact match\r\n        Returns:\r\n            Updated script args\r\n        when script_name in not found or arg_elem_id is not found in script controls, raise RuntimeError\r\n        \"\"\"\r\n        script = next((x for x in self.scripts if x.name == script_name), None)\r\n        if script is None:\r\n            raise RuntimeError(f\"script {script_name} not found\")\r\n\r\n        for i, control in enumerate(script.controls):\r\n            if arg_elem_id in control.elem_id if fuzzy else arg_elem_id == control.elem_id:\r\n                index = script.args_from + i\r\n\r\n                if isinstance(args, tuple):\r\n                    return args[:index] + (value,) + args[index + 1:]\r\n                elif isinstance(args, list):\r\n                    args[index] = value\r\n                    return args\r\n                else:\r\n                    raise RuntimeError(f\"args is not a list or tuple, but {type(args)}\")\r\n        raise RuntimeError(f\"arg_elem_id {arg_elem_id} not found in script {script_name}\")\r\n\r\n\r\nscripts_txt2img: ScriptRunner = None\r\nscripts_img2img: ScriptRunner = None\r\nscripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None\r\nscripts_current: ScriptRunner = None\r\n\r\n\r\ndef reload_script_body_only():\r\n    cache = {}\r\n    scripts_txt2img.reload_sources(cache)\r\n    scripts_img2img.reload_sources(cache)\r\n\r\n\r\nreload_scripts = load_scripts  # compatibility alias\r\n"
  },
  {
    "path": "modules/scripts_auto_postprocessing.py",
    "content": "from modules import scripts, scripts_postprocessing, shared\r\n\r\n\r\nclass ScriptPostprocessingForMainUI(scripts.Script):\r\n    def __init__(self, script_postproc):\r\n        self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc\r\n        self.postprocessing_controls = None\r\n\r\n    def title(self):\r\n        return self.script.name\r\n\r\n    def show(self, is_img2img):\r\n        return scripts.AlwaysVisible\r\n\r\n    def ui(self, is_img2img):\r\n        self.postprocessing_controls = self.script.ui()\r\n        return self.postprocessing_controls.values()\r\n\r\n    def postprocess_image(self, p, script_pp, *args):\r\n        args_dict = dict(zip(self.postprocessing_controls, args))\r\n\r\n        pp = scripts_postprocessing.PostprocessedImage(script_pp.image)\r\n        pp.info = {}\r\n        self.script.process(pp, **args_dict)\r\n        p.extra_generation_params.update(pp.info)\r\n        script_pp.image = pp.image\r\n\r\n\r\ndef create_auto_preprocessing_script_data():\r\n    from modules import scripts\r\n\r\n    res = []\r\n\r\n    for name in shared.opts.postprocessing_enable_in_main_ui:\r\n        script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None)\r\n        if script is None:\r\n            continue\r\n\r\n        constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class())\r\n        res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module))\r\n\r\n    return res\r\n"
  },
  {
    "path": "modules/scripts_postprocessing.py",
    "content": "import dataclasses\r\nimport os\r\nimport gradio as gr\r\n\r\nfrom modules import errors, shared\r\n\r\n\r\n@dataclasses.dataclass\r\nclass PostprocessedImageSharedInfo:\r\n    target_width: int = None\r\n    target_height: int = None\r\n\r\n\r\nclass PostprocessedImage:\r\n    def __init__(self, image):\r\n        self.image = image\r\n        self.info = {}\r\n        self.shared = PostprocessedImageSharedInfo()\r\n        self.extra_images = []\r\n        self.nametags = []\r\n        self.disable_processing = False\r\n        self.caption = None\r\n\r\n    def get_suffix(self, used_suffixes=None):\r\n        used_suffixes = {} if used_suffixes is None else used_suffixes\r\n        suffix = \"-\".join(self.nametags)\r\n        if suffix:\r\n            suffix = \"-\" + suffix\r\n\r\n        if suffix not in used_suffixes:\r\n            used_suffixes[suffix] = 1\r\n            return suffix\r\n\r\n        for i in range(1, 100):\r\n            proposed_suffix = suffix + \"-\" + str(i)\r\n\r\n            if proposed_suffix not in used_suffixes:\r\n                used_suffixes[proposed_suffix] = 1\r\n                return proposed_suffix\r\n\r\n        return suffix\r\n\r\n    def create_copy(self, new_image, *, nametags=None, disable_processing=False):\r\n        pp = PostprocessedImage(new_image)\r\n        pp.shared = self.shared\r\n        pp.nametags = self.nametags.copy()\r\n        pp.info = self.info.copy()\r\n        pp.disable_processing = disable_processing\r\n\r\n        if nametags is not None:\r\n            pp.nametags += nametags\r\n\r\n        return pp\r\n\r\n\r\nclass ScriptPostprocessing:\r\n    filename = None\r\n    controls = None\r\n    args_from = None\r\n    args_to = None\r\n\r\n    order = 1000\r\n    \"\"\"scripts will be ordred by this value in postprocessing UI\"\"\"\r\n\r\n    name = None\r\n    \"\"\"this function should return the title of the script.\"\"\"\r\n\r\n    group = None\r\n    \"\"\"A gr.Group 
component that has all script's UI inside it\"\"\"\r\n\r\n    def ui(self):\r\n        \"\"\"\r\n        This function should create gradio UI elements. See https://gradio.app/docs/#components\r\n        The return value should be a dictionary that maps parameter names to components used in processing.\r\n        Values of those components will be passed to process() function.\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def process(self, pp: PostprocessedImage, **args):\r\n        \"\"\"\r\n        This function is called to postprocess the image.\r\n        args contains a dictionary with all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def process_firstpass(self, pp: PostprocessedImage, **args):\r\n        \"\"\"\r\n        Called for all scripts before calling process(). Scripts can examine the image here and set fields\r\n        of the pp object to communicate things to other scripts.\r\n        args contains a dictionary with all values returned by components from ui()\r\n        \"\"\"\r\n\r\n        pass\r\n\r\n    def image_changed(self):\r\n        pass\r\n\r\n\r\ndef wrap_call(func, filename, funcname, *args, default=None, **kwargs):\r\n    try:\r\n        res = func(*args, **kwargs)\r\n        return res\r\n    except Exception as e:\r\n        errors.display(e, f\"calling {filename}/{funcname}\")\r\n\r\n    return default\r\n\r\n\r\nclass ScriptPostprocessingRunner:\r\n    def __init__(self):\r\n        self.scripts = None\r\n        self.ui_created = False\r\n\r\n    def initialize_scripts(self, scripts_data):\r\n        self.scripts = []\r\n\r\n        for script_data in scripts_data:\r\n            script: ScriptPostprocessing = script_data.script_class()\r\n            script.filename = script_data.path\r\n\r\n            if script.name == \"Simple Upscale\":\r\n                continue\r\n\r\n            self.scripts.append(script)\r\n\r\n    def create_script_ui(self, script, inputs):\r\n        
script.args_from = len(inputs)\r\n        script.args_to = len(inputs)\r\n\r\n        script.controls = wrap_call(script.ui, script.filename, \"ui\")\r\n\r\n        for control in script.controls.values():\r\n            control.custom_script_source = os.path.basename(script.filename)\r\n\r\n        inputs += list(script.controls.values())\r\n        script.args_to = len(inputs)\r\n\r\n    def scripts_in_preferred_order(self):\r\n        if self.scripts is None:\r\n            import modules.scripts\r\n            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)\r\n\r\n        scripts_order = shared.opts.postprocessing_operation_order\r\n        scripts_filter_out = set(shared.opts.postprocessing_disable_in_extras)\r\n\r\n        def script_score(name):\r\n            for i, possible_match in enumerate(scripts_order):\r\n                if possible_match == name:\r\n                    return i\r\n\r\n            return len(self.scripts)\r\n\r\n        filtered_scripts = [script for script in self.scripts if script.name not in scripts_filter_out]\r\n        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(filtered_scripts)}\r\n\r\n        return sorted(filtered_scripts, key=lambda x: script_scores[x.name])\r\n\r\n    def setup_ui(self):\r\n        inputs = []\r\n\r\n        for script in self.scripts_in_preferred_order():\r\n            with gr.Row() as group:\r\n                self.create_script_ui(script, inputs)\r\n\r\n            script.group = group\r\n\r\n        self.ui_created = True\r\n        return inputs\r\n\r\n    def run(self, pp: PostprocessedImage, args):\r\n        scripts = []\r\n\r\n        for script in self.scripts_in_preferred_order():\r\n            script_args = args[script.args_from:script.args_to]\r\n\r\n            process_args = {}\r\n            for (name, _component), value in zip(script.controls.items(), script_args):\r\n   
             process_args[name] = value\r\n\r\n            scripts.append((script, process_args))\r\n\r\n        for script, process_args in scripts:\r\n            script.process_firstpass(pp, **process_args)\r\n\r\n        all_images = [pp]\r\n\r\n        for script, process_args in scripts:\r\n            if shared.state.skipped:\r\n                break\r\n\r\n            shared.state.job = script.name\r\n\r\n            for single_image in all_images.copy():\r\n\r\n                if not single_image.disable_processing:\r\n                    script.process(single_image, **process_args)\r\n\r\n                for extra_image in single_image.extra_images:\r\n                    if not isinstance(extra_image, PostprocessedImage):\r\n                        extra_image = single_image.create_copy(extra_image)\r\n\r\n                    all_images.append(extra_image)\r\n\r\n                single_image.extra_images.clear()\r\n\r\n        pp.extra_images = all_images[1:]\r\n\r\n    def create_args_for_run(self, scripts_args):\r\n        if not self.ui_created:\r\n            with gr.Blocks(analytics_enabled=False):\r\n                self.setup_ui()\r\n\r\n        scripts = self.scripts_in_preferred_order()\r\n        args = [None] * max([x.args_to for x in scripts])\r\n\r\n        for script in scripts:\r\n            script_args_dict = scripts_args.get(script.name, None)\r\n            if script_args_dict is not None:\r\n\r\n                for i, name in enumerate(script.controls):\r\n                    args[script.args_from + i] = script_args_dict.get(name, None)\r\n\r\n        return args\r\n\r\n    def image_changed(self):\r\n        for script in self.scripts_in_preferred_order():\r\n            script.image_changed()\r\n\r\n"
  },
  {
    "path": "modules/sd_disable_initialization.py",
    "content": "import ldm.modules.encoders.modules\r\nimport open_clip\r\nimport torch\r\nimport transformers.utils.hub\r\n\r\nfrom modules import shared\r\n\r\n\r\nclass ReplaceHelper:\r\n    def __init__(self):\r\n        self.replaced = []\r\n\r\n    def replace(self, obj, field, func):\r\n        original = getattr(obj, field, None)\r\n        if original is None:\r\n            return None\r\n\r\n        self.replaced.append((obj, field, original))\r\n        setattr(obj, field, func)\r\n\r\n        return original\r\n\r\n    def restore(self):\r\n        for obj, field, original in self.replaced:\r\n            setattr(obj, field, original)\r\n\r\n        self.replaced.clear()\r\n\r\n\r\nclass DisableInitialization(ReplaceHelper):\r\n    \"\"\"\r\n    When an object of this class enters a `with` block, it starts:\r\n    - preventing torch's layer initialization functions from working\r\n    - changes CLIP and OpenCLIP to not download model weights\r\n    - changes CLIP to not make requests to check if there is a new version of a file you already have\r\n\r\n    When it leaves the block, it reverts everything to how it was before.\r\n\r\n    Use it like this:\r\n    ```\r\n    with DisableInitialization():\r\n        do_things()\r\n    ```\r\n    \"\"\"\r\n\r\n    def __init__(self, disable_clip=True):\r\n        super().__init__()\r\n        self.disable_clip = disable_clip\r\n\r\n    def replace(self, obj, field, func):\r\n        original = getattr(obj, field, None)\r\n        if original is None:\r\n            return None\r\n\r\n        self.replaced.append((obj, field, original))\r\n        setattr(obj, field, func)\r\n\r\n        return original\r\n\r\n    def __enter__(self):\r\n        def do_nothing(*args, **kwargs):\r\n            pass\r\n\r\n        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):\r\n            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)\r\n\r\n        def 
CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):\r\n            res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)\r\n            res.name_or_path = pretrained_model_name_or_path\r\n            return res\r\n\r\n        def transformers_modeling_utils_load_pretrained_model(*args, **kwargs):\r\n            args = args[0:3] + ('/', ) + args[4:]  # resolved_archive_file; must set it to something to prevent what seems to be a bug\r\n            return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs)\r\n\r\n        def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs):\r\n\r\n            # this file is always 404, prevent making request\r\n            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json':\r\n                return None\r\n\r\n            try:\r\n                res = original(url, *args, local_files_only=True, **kwargs)\r\n                if res is None:\r\n                    res = original(url, *args, local_files_only=False, **kwargs)\r\n                return res\r\n            except Exception:\r\n                return original(url, *args, local_files_only=False, **kwargs)\r\n\r\n        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):\r\n            return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs)\r\n\r\n        def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs):\r\n            return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs)\r\n\r\n        def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs):\r\n            
return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs)\r\n\r\n        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)\r\n        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)\r\n        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)\r\n\r\n        if self.disable_clip:\r\n            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)\r\n            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)\r\n            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)\r\n            self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file)\r\n            self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)\r\n            self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        self.restore()\r\n\r\n\r\nclass InitializeOnMeta(ReplaceHelper):\r\n    \"\"\"\r\n    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,\r\n    which results in those parameters having no values and taking no memory. 
model.to() will be broken and\r\n    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.\r\n\r\n    Usage:\r\n    ```\r\n    with sd_disable_initialization.InitializeOnMeta():\r\n        sd_model = instantiate_from_config(sd_config.model)\r\n    ```\r\n    \"\"\"\r\n\r\n    def __enter__(self):\r\n        if shared.cmd_opts.disable_model_loading_ram_optimization:\r\n            return\r\n\r\n        def set_device(x):\r\n            x[\"device\"] = \"meta\"\r\n            return x\r\n\r\n        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))\r\n        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))\r\n        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))\r\n        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        self.restore()\r\n\r\n\r\nclass LoadStateDictOnMeta(ReplaceHelper):\r\n    \"\"\"\r\n    Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.\r\n    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.\r\n    Meant to be used together with InitializeOnMeta above.\r\n\r\n    Usage:\r\n    ```\r\n    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):\r\n        model.load_state_dict(state_dict, strict=False)\r\n    ```\r\n    \"\"\"\r\n\r\n    def __init__(self, state_dict, device, weight_dtype_conversion=None):\r\n        super().__init__()\r\n        self.state_dict = state_dict\r\n        self.device = device\r\n        self.weight_dtype_conversion = weight_dtype_conversion or {}\r\n        self.default_dtype = 
self.weight_dtype_conversion.get('')\r\n\r\n    def get_weight_dtype(self, key):\r\n        key_first_term, _ = key.split('.', 1)\r\n        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)\r\n\r\n    def __enter__(self):\r\n        if shared.cmd_opts.disable_model_loading_ram_optimization:\r\n            return\r\n\r\n        sd = self.state_dict\r\n        device = self.device\r\n\r\n        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):\r\n            used_param_keys = []\r\n\r\n            for name, param in module._parameters.items():\r\n                if param is None:\r\n                    continue\r\n\r\n                key = prefix + name\r\n                sd_param = sd.pop(key, None)\r\n                if sd_param is not None:\r\n                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))\r\n                    used_param_keys.append(key)\r\n\r\n                if param.is_meta:\r\n                    dtype = sd_param.dtype if sd_param is not None else param.dtype\r\n                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)\r\n\r\n            for name in module._buffers:\r\n                key = prefix + name\r\n\r\n                sd_param = sd.pop(key, None)\r\n                if sd_param is not None:\r\n                    state_dict[key] = sd_param\r\n                    used_param_keys.append(key)\r\n\r\n            original(module, state_dict, prefix, *args, **kwargs)\r\n\r\n            for key in used_param_keys:\r\n                state_dict.pop(key, None)\r\n\r\n        def load_state_dict(original, module, state_dict, strict=True):\r\n            \"\"\"torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help\r\n            because the same values are stored in multiple copies of the dict. 
The trick used here is to give torch a dict with\r\n            all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.\r\n\r\n            In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).\r\n\r\n            The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads\r\n            the function and does not call the original) the state dict will just fail to load because weights\r\n            would be on the meta device.\r\n            \"\"\"\r\n\r\n            if state_dict is sd:\r\n                state_dict = {k: v.to(device=\"meta\", dtype=v.dtype) for k, v in state_dict.items()}\r\n\r\n            original(module, state_dict, strict=strict)\r\n\r\n        module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))\r\n        module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))\r\n        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))\r\n        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))\r\n        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))\r\n        layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))\r\n        group_norm_load_from_state_dict = 
self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        self.restore()\r\n"
  },
  {
    "path": "modules/sd_emphasis.py",
    "content": "from __future__ import annotations\r\nimport torch\r\n\r\n\r\nclass Emphasis:\r\n    \"\"\"Emphasis class decides how to deal with (emphasized:1.1) text in prompts\"\"\"\r\n\r\n    name: str = \"Base\"\r\n    description: str = \"\"\r\n\r\n    tokens: list[list[int]]\r\n    \"\"\"tokens from the chunk of the prompt\"\"\"\r\n\r\n    multipliers: torch.Tensor\r\n    \"\"\"tensor with multipliers, one for each token\"\"\"\r\n\r\n    z: torch.Tensor\r\n    \"\"\"output of cond transformers network (CLIP)\"\"\"\r\n\r\n    def after_transformers(self):\r\n        \"\"\"Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis\"\"\"\r\n\r\n        pass\r\n\r\n\r\nclass EmphasisNone(Emphasis):\r\n    name = \"None\"\r\n    description = \"disable the mechanism entirely and treat (:.1.1) as literal characters\"\r\n\r\n\r\nclass EmphasisIgnore(Emphasis):\r\n    name = \"Ignore\"\r\n    description = \"treat all empasised words as if they have no emphasis\"\r\n\r\n\r\nclass EmphasisOriginal(Emphasis):\r\n    name = \"Original\"\r\n    description = \"the original emphasis implementation\"\r\n\r\n    def after_transformers(self):\r\n        original_mean = self.z.mean()\r\n        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)\r\n\r\n        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise\r\n        new_mean = self.z.mean()\r\n        self.z = self.z * (original_mean / new_mean)\r\n\r\n\r\nclass EmphasisOriginalNoNorm(EmphasisOriginal):\r\n    name = \"No norm\"\r\n    description = \"same as original, but without normalization (seems to work better for SDXL)\"\r\n\r\n    def after_transformers(self):\r\n        self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape)\r\n\r\n\r\ndef get_current_option(emphasis_option_name):\r\n    return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal)\r\n\r\n\r\ndef get_options_descriptions():\r\n    return \", \".join(f\"{x.name}: {x.description}\" for x in options)\r\n\r\n\r\noptions = [\r\n    EmphasisNone,\r\n    EmphasisIgnore,\r\n    EmphasisOriginal,\r\n    EmphasisOriginalNoNorm,\r\n]\r\n"
  },
  {
    "path": "modules/sd_hijack.py",
    "content": "import torch\r\nfrom torch.nn.functional import silu\r\nfrom types import MethodType\r\n\r\nfrom modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches\r\nfrom modules.hypernetworks import hypernetwork\r\nfrom modules.shared import cmd_opts\r\nfrom modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18\r\n\r\nimport ldm.modules.attention\r\nimport ldm.modules.diffusionmodules.model\r\nimport ldm.modules.diffusionmodules.openaimodel\r\nimport ldm.models.diffusion.ddpm\r\nimport ldm.models.diffusion.ddim\r\nimport ldm.models.diffusion.plms\r\nimport ldm.modules.encoders.modules\r\n\r\nimport sgm.modules.attention\r\nimport sgm.modules.diffusionmodules.model\r\nimport sgm.modules.diffusionmodules.openaimodel\r\nimport sgm.modules.encoders.modules\r\n\r\nattention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward\r\ndiffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity\r\ndiffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward\r\n\r\n# new memory efficient cross attention blocks do not support hypernets and we already\r\n# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention\r\nldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention\r\nldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES[\"softmax-xformers\"] = ldm.modules.attention.CrossAttention\r\n\r\n# silence new console spam from SD2\r\nldm.modules.attention.print = shared.ldm_print\r\nldm.modules.diffusionmodules.model.print = shared.ldm_print\r\nldm.util.print = shared.ldm_print\r\nldm.models.diffusion.ddpm.print = shared.ldm_print\r\n\r\noptimizers = []\r\ncurrent_optimizer: sd_hijack_optimizations.SdOptimization = None\r\n\r\nldm_patched_forward = 
sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)\r\nldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, \"forward\", ldm_patched_forward)\r\n\r\nsgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)\r\nsgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, \"forward\", sgm_patched_forward)\r\n\r\n\r\ndef list_optimizers():\r\n    new_optimizers = script_callbacks.list_optimizers_callback()\r\n\r\n    new_optimizers = [x for x in new_optimizers if x.is_available()]\r\n\r\n    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)\r\n\r\n    optimizers.clear()\r\n    optimizers.extend(new_optimizers)\r\n\r\n\r\ndef apply_optimizations(option=None):\r\n    global current_optimizer\r\n\r\n    undo_optimizations()\r\n\r\n    if len(optimizers) == 0:\r\n        # a script can access the model very early, and optimizations would not be filled by then\r\n        current_optimizer = None\r\n        return ''\r\n\r\n    ldm.modules.diffusionmodules.model.nonlinearity = silu\r\n    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th\r\n\r\n    sgm.modules.diffusionmodules.model.nonlinearity = silu\r\n    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th\r\n\r\n    if current_optimizer is not None:\r\n        current_optimizer.undo()\r\n        current_optimizer = None\r\n\r\n    selection = option or shared.opts.cross_attention_optimization\r\n    if selection == \"Automatic\" and len(optimizers) > 0:\r\n        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])\r\n    else:\r\n        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)\r\n\r\n    if selection == \"None\":\r\n        matching_optimizer = None\r\n    elif 
selection == \"Automatic\" and shared.cmd_opts.disable_opt_split_attention:\r\n        matching_optimizer = None\r\n    elif matching_optimizer is None:\r\n        matching_optimizer = optimizers[0]\r\n\r\n    if matching_optimizer is not None:\r\n        print(f\"Applying attention optimization: {matching_optimizer.name}... \", end='')\r\n        matching_optimizer.apply()\r\n        print(\"done.\")\r\n        current_optimizer = matching_optimizer\r\n        return current_optimizer.name\r\n    else:\r\n        print(\"Disabling attention optimization\")\r\n        return ''\r\n\r\n\r\ndef undo_optimizations():\r\n    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity\r\n    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward\r\n    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward\r\n\r\n    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity\r\n    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward\r\n    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward\r\n\r\n\r\ndef fix_checkpoint():\r\n    \"\"\"checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want\r\n    checkpoints to be added when not training (there's a warning)\"\"\"\r\n\r\n    pass\r\n\r\n\r\ndef weighted_loss(sd_model, pred, target, mean=True):\r\n    #Calculate the weight normally, but ignore the mean\r\n    loss = sd_model._old_get_loss(pred, target, mean=False)\r\n\r\n    #Check if we have weights available\r\n    weight = getattr(sd_model, '_custom_loss_weight', None)\r\n    if weight is not None:\r\n        loss *= weight\r\n\r\n    #Return the loss, as mean if specified\r\n    return loss.mean() if mean else loss\r\n\r\ndef weighted_forward(sd_model, x, c, w, *args, **kwargs):\r\n    try:\r\n        #Temporarily 
append weights to a place accessible during loss calc\r\n        sd_model._custom_loss_weight = w\r\n\r\n        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely\r\n        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set\r\n        if not hasattr(sd_model, '_old_get_loss'):\r\n            sd_model._old_get_loss = sd_model.get_loss\r\n        sd_model.get_loss = MethodType(weighted_loss, sd_model)\r\n\r\n        #Run the standard forward function, but with the patched 'get_loss'\r\n        return sd_model.forward(x, c, *args, **kwargs)\r\n    finally:\r\n        try:\r\n            #Delete temporary weights if appended\r\n            del sd_model._custom_loss_weight\r\n        except AttributeError:\r\n            pass\r\n\r\n        #If we have an old loss function, reset the loss function to the original one\r\n        if hasattr(sd_model, '_old_get_loss'):\r\n            sd_model.get_loss = sd_model._old_get_loss\r\n            del sd_model._old_get_loss\r\n\r\ndef apply_weighted_forward(sd_model):\r\n    #Add new function 'weighted_forward' that can be called to calc weighted loss\r\n    sd_model.weighted_forward = MethodType(weighted_forward, sd_model)\r\n\r\ndef undo_weighted_forward(sd_model):\r\n    try:\r\n        del sd_model.weighted_forward\r\n    except AttributeError:\r\n        pass\r\n\r\n\r\nclass StableDiffusionModelHijack:\r\n    fixes = None\r\n    layers = None\r\n    circular_enabled = False\r\n    clip = None\r\n    optimization_method = None\r\n\r\n    def __init__(self):\r\n        import modules.textual_inversion.textual_inversion\r\n\r\n        self.extra_generation_params = {}\r\n        self.comments = []\r\n\r\n        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()\r\n        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)\r\n\r\n    def apply_optimizations(self, option=None):\r\n        try:\r\n        
    self.optimization_method = apply_optimizations(option)\r\n        except Exception as e:\r\n            errors.display(e, \"applying cross attention optimization\")\r\n            undo_optimizations()\r\n\r\n    def convert_sdxl_to_ssd(self, m):\r\n        \"\"\"Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)\"\"\"\r\n\r\n        delattr(m.model.diffusion_model.middle_block, '1')\r\n        delattr(m.model.diffusion_model.middle_block, '2')\r\n        for i in ['9', '8', '7', '6', '5', '4']:\r\n            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)\r\n            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)\r\n            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)\r\n            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)\r\n        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')\r\n        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')\r\n        devices.torch_gc()\r\n\r\n    def hijack(self, m):\r\n        conditioner = getattr(m, 'conditioner', None)\r\n        if conditioner:\r\n            text_cond_models = []\r\n\r\n            for i in range(len(conditioner.embedders)):\r\n                embedder = conditioner.embedders[i]\r\n                typename = type(embedder).__name__\r\n                if typename == 'FrozenOpenCLIPEmbedder':\r\n                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)\r\n                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)\r\n                    text_cond_models.append(conditioner.embedders[i])\r\n                if typename == 'FrozenCLIPEmbedder':\r\n                    model_embeddings = embedder.transformer.text_model.embeddings\r\n                    
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)\r\n                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)\r\n                    text_cond_models.append(conditioner.embedders[i])\r\n                if typename == 'FrozenOpenCLIPEmbedder2':\r\n                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')\r\n                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)\r\n                    text_cond_models.append(conditioner.embedders[i])\r\n\r\n            if len(text_cond_models) == 1:\r\n                m.cond_stage_model = text_cond_models[0]\r\n            else:\r\n                m.cond_stage_model = conditioner\r\n\r\n        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:\r\n            model_embeddings = m.cond_stage_model.roberta.embeddings\r\n            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)\r\n            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)\r\n\r\n        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:\r\n            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings\r\n            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)\r\n            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)\r\n\r\n        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:\r\n            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)\r\n            m.cond_stage_model = 
sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)\r\n\r\n        apply_weighted_forward(m)\r\n        if m.cond_stage_key == \"edit\":\r\n            sd_hijack_unet.hijack_ddpm_edit()\r\n\r\n        self.apply_optimizations()\r\n\r\n        self.clip = m.cond_stage_model\r\n\r\n        def flatten(el):\r\n            flattened = [flatten(children) for children in el.children()]\r\n            res = [el]\r\n            for c in flattened:\r\n                res += c\r\n            return res\r\n\r\n        self.layers = flatten(m)\r\n\r\n        import modules.models.diffusion.ddpm_edit\r\n\r\n        if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):\r\n            sd_unet.original_forward = ldm_original_forward\r\n        elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):\r\n            sd_unet.original_forward = ldm_original_forward\r\n        elif isinstance(m, sgm.models.diffusion.DiffusionEngine):\r\n            sd_unet.original_forward = sgm_original_forward\r\n        else:\r\n            sd_unet.original_forward = None\r\n\r\n\r\n    def undo_hijack(self, m):\r\n        conditioner = getattr(m, 'conditioner', None)\r\n        if conditioner:\r\n            for i in range(len(conditioner.embedders)):\r\n                embedder = conditioner.embedders[i]\r\n                if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):\r\n                    embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped\r\n                    conditioner.embedders[i] = embedder.wrapped\r\n                if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):\r\n                    embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped\r\n                    
conditioner.embedders[i] = embedder.wrapped\r\n\r\n            if hasattr(m, 'cond_stage_model'):\r\n                delattr(m, 'cond_stage_model')\r\n\r\n        elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:\r\n            m.cond_stage_model = m.cond_stage_model.wrapped\r\n\r\n        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:\r\n            m.cond_stage_model = m.cond_stage_model.wrapped\r\n\r\n            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings\r\n            if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:\r\n                model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped\r\n        elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords:\r\n            m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped\r\n            m.cond_stage_model = m.cond_stage_model.wrapped\r\n\r\n        undo_optimizations()\r\n        undo_weighted_forward(m)\r\n\r\n        self.apply_circular(False)\r\n        self.layers = None\r\n        self.clip = None\r\n\r\n\r\n    def apply_circular(self, enable):\r\n        if self.circular_enabled == enable:\r\n            return\r\n\r\n        self.circular_enabled = enable\r\n\r\n        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:\r\n            layer.padding_mode = 'circular' if enable else 'zeros'\r\n\r\n    def clear_comments(self):\r\n        self.comments = []\r\n        self.extra_generation_params = {}\r\n\r\n    def get_prompt_lengths(self, text):\r\n        if self.clip is None:\r\n            return \"-\", \"-\"\r\n\r\n        if hasattr(self.clip, 'get_token_count'):\r\n            token_count = self.clip.get_token_count(text)\r\n        else:\r\n            _, token_count = self.clip.process_texts([text])\r\n\r\n        return token_count, 
self.clip.get_target_prompt_token_count(token_count)\r\n\r\n    def redo_hijack(self, m):\r\n        self.undo_hijack(m)\r\n        self.hijack(m)\r\n\r\n\r\nclass EmbeddingsWithFixes(torch.nn.Module):\r\n    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):\r\n        super().__init__()\r\n        self.wrapped = wrapped\r\n        self.embeddings = embeddings\r\n        self.textual_inversion_key = textual_inversion_key\r\n\r\n    def forward(self, input_ids):\r\n        batch_fixes = self.embeddings.fixes\r\n        self.embeddings.fixes = None\r\n\r\n        inputs_embeds = self.wrapped(input_ids)\r\n\r\n        if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:\r\n            return inputs_embeds\r\n\r\n        vecs = []\r\n        for fixes, tensor in zip(batch_fixes, inputs_embeds):\r\n            for offset, embedding in fixes:\r\n                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec\r\n                emb = devices.cond_cast_unet(vec)\r\n                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])\r\n                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype)\r\n\r\n            vecs.append(tensor)\r\n\r\n        return torch.stack(vecs)\r\n\r\n\r\nclass TextualInversionEmbeddings(torch.nn.Embedding):\r\n    def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs):\r\n        super().__init__(num_embeddings, embedding_dim, **kwargs)\r\n\r\n        self.embeddings = model_hijack\r\n        self.textual_inversion_key = textual_inversion_key\r\n\r\n    @property\r\n    def wrapped(self):\r\n        return super().forward\r\n\r\n    def forward(self, input_ids):\r\n        return EmbeddingsWithFixes.forward(self, input_ids)\r\n\r\n\r\ndef add_circular_option_to_conv_2d():\r\n    conv2d_constructor = 
torch.nn.Conv2d.__init__\r\n\r\n    def conv2d_constructor_circular(self, *args, **kwargs):\r\n        return conv2d_constructor(self, *args, padding_mode='circular', **kwargs)\r\n\r\n    torch.nn.Conv2d.__init__ = conv2d_constructor_circular\r\n\r\n\r\nmodel_hijack = StableDiffusionModelHijack()\r\n\r\n\r\ndef register_buffer(self, name, attr):\r\n    \"\"\"\r\n    Fix register buffer bug for Mac OS.\r\n    \"\"\"\r\n\r\n    if type(attr) == torch.Tensor:\r\n        if attr.device != devices.device:\r\n            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))\r\n\r\n    setattr(self, name, attr)\r\n\r\n\r\nldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer\r\nldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer\r\n"
  },
  {
    "path": "modules/sd_hijack_checkpoint.py",
    "content": "from torch.utils.checkpoint import checkpoint\n\nimport ldm.modules.attention\nimport ldm.modules.diffusionmodules.openaimodel\n\n\ndef BasicTransformerBlock_forward(self, x, context=None):\n    return checkpoint(self._forward, x, context)\n\n\ndef AttentionBlock_forward(self, x):\n    return checkpoint(self._forward, x)\n\n\ndef ResBlock_forward(self, x, emb):\n    return checkpoint(self._forward, x, emb)\n\n\nstored = []\n\n\ndef add():\n    if len(stored) != 0:\n        return\n\n    stored.extend([\n        ldm.modules.attention.BasicTransformerBlock.forward,\n        ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,\n        ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward\n    ])\n\n    ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward\n    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward\n    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward\n\n\ndef remove():\n    if len(stored) == 0:\n        return\n\n    ldm.modules.attention.BasicTransformerBlock.forward = stored[0]\n    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]\n    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]\n\n    stored.clear()\n\n"
  },
  {
    "path": "modules/sd_hijack_clip.py",
    "content": "import math\r\nfrom collections import namedtuple\r\n\r\nimport torch\r\n\r\nfrom modules import prompt_parser, devices, sd_hijack, sd_emphasis\r\nfrom modules.shared import opts\r\n\r\n\r\nclass PromptChunk:\r\n    \"\"\"\r\n    This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.\r\n    If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.\r\n    Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,\r\n    so just 75 tokens from prompt.\r\n    \"\"\"\r\n\r\n    def __init__(self):\r\n        self.tokens = []\r\n        self.multipliers = []\r\n        self.fixes = []\r\n\r\n\r\nPromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])\r\n\"\"\"An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt\r\nchunk. Those objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally\r\nare applied by sd_hijack.EmbeddingsWithFixes's forward function.\"\"\"\r\n\r\n\r\nclass TextConditionalModel(torch.nn.Module):\r\n    def __init__(self):\r\n        super().__init__()\r\n\r\n        self.hijack = sd_hijack.model_hijack\r\n        self.chunk_length = 75\r\n\r\n        self.is_trainable = False\r\n        self.input_key = 'txt'\r\n        self.return_pooled = False\r\n\r\n        self.comma_token = None\r\n        self.id_start = None\r\n        self.id_end = None\r\n        self.id_pad = None\r\n\r\n    def empty_chunk(self):\r\n        \"\"\"creates an empty PromptChunk and returns it\"\"\"\r\n\r\n        chunk = PromptChunk()\r\n        chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)\r\n        chunk.multipliers = [1.0] * (self.chunk_length + 2)\r\n        return chunk\r\n\r\n    def get_target_prompt_token_count(self, token_count):\r\n     
   \"\"\"returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented\"\"\"\r\n\r\n        return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length\r\n\r\n    def tokenize(self, texts):\r\n        \"\"\"Converts a batch of texts into a batch of token ids\"\"\"\r\n\r\n        raise NotImplementedError\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        \"\"\"\r\n        converts a batch of token ids (in python lists) into a single tensor with numeric representation of those tokens;\r\n        All python lists with tokens are assumed to have same length, usually 77.\r\n        if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on\r\n        model - can be 768 and 1024.\r\n        Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).\r\n        \"\"\"\r\n\r\n        raise NotImplementedError\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        \"\"\"Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through\r\n        transformers. nvpt is used as a maximum length in tokens. 
If text produces less teokens than nvpt, only this many is returned.\"\"\"\r\n\r\n        raise NotImplementedError\r\n\r\n    def tokenize_line(self, line):\r\n        \"\"\"\r\n        this transforms a single prompt into a list of PromptChunk objects - as many as needed to\r\n        represent the prompt.\r\n        Returns the list and the total number of tokens in the prompt.\r\n        \"\"\"\r\n\r\n        if opts.emphasis != \"None\":\r\n            parsed = prompt_parser.parse_prompt_attention(line)\r\n        else:\r\n            parsed = [[line, 1.0]]\r\n\r\n        tokenized = self.tokenize([text for text, _ in parsed])\r\n\r\n        chunks = []\r\n        chunk = PromptChunk()\r\n        token_count = 0\r\n        last_comma = -1\r\n\r\n        def next_chunk(is_last=False):\r\n            \"\"\"puts current chunk into the list of results and produces the next one - empty;\r\n            if is_last is true, tokens <end-of-text> tokens at the end won't add to token_count\"\"\"\r\n            nonlocal token_count\r\n            nonlocal last_comma\r\n            nonlocal chunk\r\n\r\n            if is_last:\r\n                token_count += len(chunk.tokens)\r\n            else:\r\n                token_count += self.chunk_length\r\n\r\n            to_add = self.chunk_length - len(chunk.tokens)\r\n            if to_add > 0:\r\n                chunk.tokens += [self.id_end] * to_add\r\n                chunk.multipliers += [1.0] * to_add\r\n\r\n            chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]\r\n            chunk.multipliers = [1.0] + chunk.multipliers + [1.0]\r\n\r\n            last_comma = -1\r\n            chunks.append(chunk)\r\n            chunk = PromptChunk()\r\n\r\n        for tokens, (text, weight) in zip(tokenized, parsed):\r\n            if text == 'BREAK' and weight == -1:\r\n                next_chunk()\r\n                continue\r\n\r\n            position = 0\r\n            while position < len(tokens):\r\n          
      token = tokens[position]\r\n\r\n                if token == self.comma_token:\r\n                    last_comma = len(chunk.tokens)\r\n\r\n                # this is when we are at the end of allotted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack\r\n                # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.\r\n                elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:\r\n                    break_location = last_comma + 1\r\n\r\n                    reloc_tokens = chunk.tokens[break_location:]\r\n                    reloc_mults = chunk.multipliers[break_location:]\r\n\r\n                    chunk.tokens = chunk.tokens[:break_location]\r\n                    chunk.multipliers = chunk.multipliers[:break_location]\r\n\r\n                    next_chunk()\r\n                    chunk.tokens = reloc_tokens\r\n                    chunk.multipliers = reloc_mults\r\n\r\n                if len(chunk.tokens) == self.chunk_length:\r\n                    next_chunk()\r\n\r\n                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)\r\n                if embedding is None:\r\n                    chunk.tokens.append(token)\r\n                    chunk.multipliers.append(weight)\r\n                    position += 1\r\n                    continue\r\n\r\n                emb_len = int(embedding.vectors)\r\n                if len(chunk.tokens) + emb_len > self.chunk_length:\r\n                    next_chunk()\r\n\r\n                chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))\r\n\r\n                chunk.tokens += [0] * emb_len\r\n                chunk.multipliers += [weight] * emb_len\r\n                position += 
embedding_length_in_tokens\r\n\r\n        if chunk.tokens or not chunks:\r\n            next_chunk(is_last=True)\r\n\r\n        return chunks, token_count\r\n\r\n    def process_texts(self, texts):\r\n        \"\"\"\r\n        Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum\r\n        length, in tokens, of all texts.\r\n        \"\"\"\r\n\r\n        token_count = 0\r\n\r\n        cache = {}\r\n        batch_chunks = []\r\n        for line in texts:\r\n            if line in cache:\r\n                chunks = cache[line]\r\n            else:\r\n                chunks, current_token_count = self.tokenize_line(line)\r\n                token_count = max(current_token_count, token_count)\r\n\r\n                cache[line] = chunks\r\n\r\n            batch_chunks.append(chunks)\r\n\r\n        return batch_chunks, token_count\r\n\r\n    def forward(self, texts):\r\n        \"\"\"\r\n        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.\r\n        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will\r\n        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.\r\n        An example shape returned by this function can be: (2, 77, 768).\r\n        For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.\r\n        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element\r\n        is when you do prompt editing: \"a picture of a [cat:dog:0.4] eating ice cream\"\r\n        \"\"\"\r\n\r\n        batch_chunks, token_count = self.process_texts(texts)\r\n\r\n        used_embeddings = {}\r\n        chunk_count = max([len(x) for x in 
batch_chunks])\r\n\r\n        zs = []\r\n        for i in range(chunk_count):\r\n            batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]\r\n\r\n            tokens = [x.tokens for x in batch_chunk]\r\n            multipliers = [x.multipliers for x in batch_chunk]\r\n            self.hijack.fixes = [x.fixes for x in batch_chunk]\r\n\r\n            for fixes in self.hijack.fixes:\r\n                for _position, embedding in fixes:\r\n                    used_embeddings[embedding.name] = embedding\r\n            devices.torch_npu_set_device()\r\n            z = self.process_tokens(tokens, multipliers)\r\n            zs.append(z)\r\n\r\n        if opts.textual_inversion_add_hashes_to_infotext and used_embeddings:\r\n            hashes = []\r\n            for name, embedding in used_embeddings.items():\r\n                shorthash = embedding.shorthash\r\n                if not shorthash:\r\n                    continue\r\n\r\n                name = name.replace(\":\", \"\").replace(\",\", \"\")\r\n                hashes.append(f\"{name}: {shorthash}\")\r\n\r\n            if hashes:\r\n                if self.hijack.extra_generation_params.get(\"TI hashes\"):\r\n                    hashes.append(self.hijack.extra_generation_params.get(\"TI hashes\"))\r\n                self.hijack.extra_generation_params[\"TI hashes\"] = \", \".join(hashes)\r\n\r\n        if any(x for x in texts if \"(\" in x or \"[\" in x) and opts.emphasis != \"Original\":\r\n            self.hijack.extra_generation_params[\"Emphasis\"] = opts.emphasis\r\n\r\n        if self.return_pooled:\r\n            return torch.hstack(zs), zs[0].pooled\r\n        else:\r\n            return torch.hstack(zs)\r\n\r\n    def process_tokens(self, remade_batch_tokens, batch_multipliers):\r\n        \"\"\"\r\n        sends one single prompt chunk to be encoded by transformers neural network.\r\n        remade_batch_tokens is a batch of tokens - a list, where every 
element is a list of tokens; usually\r\n        there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.\r\n        Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier\r\n        corresponds to one token.\r\n        \"\"\"\r\n        tokens = torch.asarray(remade_batch_tokens).to(devices.device)\r\n\r\n        # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.\r\n        if self.id_end != self.id_pad:\r\n            for batch_pos in range(len(remade_batch_tokens)):\r\n                index = remade_batch_tokens[batch_pos].index(self.id_end)\r\n                tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad\r\n\r\n        z = self.encode_with_transformers(tokens)\r\n\r\n        pooled = getattr(z, 'pooled', None)\r\n\r\n        emphasis = sd_emphasis.get_current_option(opts.emphasis)()\r\n        emphasis.tokens = remade_batch_tokens\r\n        emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device)\r\n        emphasis.z = z\r\n\r\n        emphasis.after_transformers()\r\n\r\n        z = emphasis.z\r\n\r\n        if pooled is not None:\r\n            z.pooled = pooled\r\n\r\n        return z\r\n\r\n\r\nclass FrozenCLIPEmbedderWithCustomWordsBase(TextConditionalModel):\r\n    \"\"\"A pytorch module that is a wrapper for FrozenCLIPEmbedder module. 
it enhances FrozenCLIPEmbedder, making it possible to\r\n    have unlimited prompt length and assign weights to tokens in prompt.\r\n    \"\"\"\r\n\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__()\r\n\r\n        self.hijack = hijack\r\n\r\n        self.wrapped = wrapped\r\n        \"\"\"Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,\r\n        depending on model.\"\"\"\r\n\r\n        self.is_trainable = getattr(wrapped, 'is_trainable', False)\r\n        self.input_key = getattr(wrapped, 'input_key', 'txt')\r\n        self.return_pooled = getattr(self.wrapped, 'return_pooled', False)\r\n\r\n        self.legacy_ucg_val = None  # for sgm codebase\r\n\r\n    def forward(self, texts):\r\n        if opts.use_old_emphasis_implementation:\r\n            import modules.sd_hijack_clip_old\r\n            return modules.sd_hijack_clip_old.forward_old(self, texts)\r\n\r\n        return super().forward(texts)\r\n\r\n\r\nclass FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__(wrapped, hijack)\r\n        self.tokenizer = wrapped.tokenizer\r\n\r\n        vocab = self.tokenizer.get_vocab()\r\n\r\n        self.comma_token = vocab.get(',</w>', None)\r\n\r\n        self.token_mults = {}\r\n        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]\r\n        for text, ident in tokens_with_parens:\r\n            mult = 1.0\r\n            for c in text:\r\n                if c == '[':\r\n                    mult /= 1.1\r\n                if c == ']':\r\n                    mult *= 1.1\r\n                if c == '(':\r\n                    mult *= 1.1\r\n                if c == ')':\r\n                    mult /= 1.1\r\n\r\n            if mult != 1.0:\r\n                self.token_mults[ident] = mult\r\n\r\n        self.id_start = 
self.wrapped.tokenizer.bos_token_id\r\n        self.id_end = self.wrapped.tokenizer.eos_token_id\r\n        self.id_pad = self.id_end\r\n\r\n    def tokenize(self, texts):\r\n        tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)[\"input_ids\"]\r\n\r\n        return tokenized\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)\r\n\r\n        if opts.CLIP_stop_at_last_layers > 1:\r\n            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]\r\n            z = self.wrapped.transformer.text_model.final_layer_norm(z)\r\n        else:\r\n            z = outputs.last_hidden_state\r\n\r\n        return z\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        embedding_layer = self.wrapped.transformer.text_model.embeddings\r\n        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors=\"pt\", add_special_tokens=False)[\"input_ids\"]\r\n        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)\r\n\r\n        return embedded\r\n\r\n\r\nclass FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__(wrapped, hijack)\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == \"hidden\")\r\n\r\n        if opts.sdxl_clip_l_skip is True:\r\n            z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]\r\n        elif self.wrapped.layer == \"last\":\r\n            z = outputs.last_hidden_state\r\n        else:\r\n            z = outputs.hidden_states[self.wrapped.layer_idx]\r\n\r\n        return z\r\n"
  },
  {
    "path": "modules/sd_hijack_clip_old.py",
    "content": "from modules import sd_hijack_clip\r\nfrom modules import shared\r\n\r\n\r\ndef process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):\r\n    id_start = self.id_start\r\n    id_end = self.id_end\r\n    maxlen = self.wrapped.max_length  # you get to stay at 77\r\n    used_custom_terms = []\r\n    remade_batch_tokens = []\r\n    hijack_comments = []\r\n    hijack_fixes = []\r\n    token_count = 0\r\n\r\n    cache = {}\r\n    batch_tokens = self.tokenize(texts)\r\n    batch_multipliers = []\r\n    for tokens in batch_tokens:\r\n        tuple_tokens = tuple(tokens)\r\n\r\n        if tuple_tokens in cache:\r\n            remade_tokens, fixes, multipliers = cache[tuple_tokens]\r\n        else:\r\n            fixes = []\r\n            remade_tokens = []\r\n            multipliers = []\r\n            mult = 1.0\r\n\r\n            i = 0\r\n            while i < len(tokens):\r\n                token = tokens[i]\r\n\r\n                embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)\r\n\r\n                mult_change = self.token_mults.get(token) if shared.opts.emphasis != \"None\" else None\r\n                if mult_change is not None:\r\n                    mult *= mult_change\r\n                    i += 1\r\n                elif embedding is None:\r\n                    remade_tokens.append(token)\r\n                    multipliers.append(mult)\r\n                    i += 1\r\n                else:\r\n                    emb_len = int(embedding.vec.shape[0])\r\n                    fixes.append((len(remade_tokens), embedding))\r\n                    remade_tokens += [0] * emb_len\r\n                    multipliers += [mult] * emb_len\r\n                    used_custom_terms.append((embedding.name, embedding.checksum()))\r\n                    i += embedding_length_in_tokens\r\n\r\n            if len(remade_tokens) > maxlen - 2:\r\n                vocab = {v: k for k, v in 
self.wrapped.tokenizer.get_vocab().items()}\r\n                ovf = remade_tokens[maxlen - 2:]\r\n                overflowing_words = [vocab.get(int(x), \"\") for x in ovf]\r\n                overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))\r\n                hijack_comments.append(f\"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\\n{overflowing_text}\\n\")\r\n\r\n            token_count = len(remade_tokens)\r\n            remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))\r\n            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]\r\n            cache[tuple_tokens] = (remade_tokens, fixes, multipliers)\r\n\r\n        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))\r\n        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]\r\n\r\n        remade_batch_tokens.append(remade_tokens)\r\n        hijack_fixes.append(fixes)\r\n        batch_multipliers.append(multipliers)\r\n    return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count\r\n\r\n\r\ndef forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts):\r\n    batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts)\r\n\r\n    self.hijack.comments += hijack_comments\r\n\r\n    if used_custom_terms:\r\n        embedding_names = \", \".join(f\"{word} [{checksum}]\" for word, checksum in used_custom_terms)\r\n        self.hijack.comments.append(f\"Used embeddings: {embedding_names}\")\r\n\r\n    self.hijack.fixes = hijack_fixes\r\n    return self.process_tokens(remade_batch_tokens, batch_multipliers)\r\n"
  },
  {
    "path": "modules/sd_hijack_ip2p.py",
    "content": "import os.path\n\n\ndef should_hijack_ip2p(checkpoint_info):\n    from modules import sd_models_config\n\n    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()\n    cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()\n\n    return \"pix2pix\" in ckpt_basename and \"pix2pix\" not in cfg_basename\n"
  },
  {
    "path": "modules/sd_hijack_open_clip.py",
    "content": "import open_clip.tokenizer\r\nimport torch\r\n\r\nfrom modules import sd_hijack_clip, devices\r\nfrom modules.shared import opts\r\n\r\ntokenizer = open_clip.tokenizer._tokenizer\r\n\r\n\r\nclass FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__(wrapped, hijack)\r\n\r\n        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]\r\n        self.id_start = tokenizer.encoder[\"<start_of_text>\"]\r\n        self.id_end = tokenizer.encoder[\"<end_of_text>\"]\r\n        self.id_pad = 0\r\n\r\n    def tokenize(self, texts):\r\n        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'\r\n\r\n        tokenized = [tokenizer.encode(text) for text in texts]\r\n\r\n        return tokenized\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers\r\n        z = self.wrapped.encode_with_transformer(tokens)\r\n\r\n        return z\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        ids = tokenizer.encode(init_text)\r\n        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)\r\n        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)\r\n\r\n        return embedded\r\n\r\n\r\nclass FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__(wrapped, hijack)\r\n\r\n        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]\r\n        self.id_start = tokenizer.encoder[\"<start_of_text>\"]\r\n        self.id_end = tokenizer.encoder[\"<end_of_text>\"]\r\n        self.id_pad = 0\r\n\r\n    def tokenize(self, texts):\r\n        assert not opts.use_old_emphasis_implementation, 'Old emphasis 
implementation not supported for Open Clip'\r\n\r\n        tokenized = [tokenizer.encode(text) for text in texts]\r\n\r\n        return tokenized\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        d = self.wrapped.encode_with_transformer(tokens)\r\n        z = d[self.wrapped.layer]\r\n\r\n        pooled = d.get(\"pooled\")\r\n        if pooled is not None:\r\n            z.pooled = pooled\r\n\r\n        return z\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        ids = tokenizer.encode(init_text)\r\n        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)\r\n        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)\r\n\r\n        return embedded\r\n"
  },
  {
    "path": "modules/sd_hijack_optimizations.py",
    "content": "from __future__ import annotations\r\nimport math\r\nimport psutil\r\nimport platform\r\n\r\nimport torch\r\nfrom torch import einsum\r\n\r\nfrom ldm.util import default\r\nfrom einops import rearrange\r\n\r\nfrom modules import shared, errors, devices, sub_quadratic_attention\r\nfrom modules.hypernetworks import hypernetwork\r\n\r\nimport ldm.modules.attention\r\nimport ldm.modules.diffusionmodules.model\r\n\r\nimport sgm.modules.attention\r\nimport sgm.modules.diffusionmodules.model\r\n\r\ndiffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward\r\nsgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward\r\n\r\n\r\nclass SdOptimization:\r\n    name: str = None\r\n    label: str | None = None\r\n    cmd_opt: str | None = None\r\n    priority: int = 0\r\n\r\n    def title(self):\r\n        if self.label is None:\r\n            return self.name\r\n\r\n        return f\"{self.name} - {self.label}\"\r\n\r\n    def is_available(self):\r\n        return True\r\n\r\n    def apply(self):\r\n        pass\r\n\r\n    def undo(self):\r\n        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward\r\n\r\n        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward\r\n\r\n\r\nclass SdOptimizationXformers(SdOptimization):\r\n    name = \"xformers\"\r\n    cmd_opt = \"xformers\"\r\n    priority = 100\r\n\r\n    def is_available(self):\r\n        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))\r\n\r\n    def apply(self):\r\n        
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward\r\n        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward\r\n\r\n\r\nclass SdOptimizationSdpNoMem(SdOptimization):\r\n    name = \"sdp-no-mem\"\r\n    label = \"scaled dot product without memory efficient attention\"\r\n    cmd_opt = \"opt_sdp_no_mem_attention\"\r\n    priority = 80\r\n\r\n    def is_available(self):\r\n        return hasattr(torch.nn.functional, \"scaled_dot_product_attention\") and callable(torch.nn.functional.scaled_dot_product_attention)\r\n\r\n    def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward\r\n        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward\r\n\r\n\r\nclass SdOptimizationSdp(SdOptimizationSdpNoMem):\r\n    name = \"sdp\"\r\n    label = \"scaled dot product\"\r\n    cmd_opt = \"opt_sdp_attention\"\r\n    priority = 70\r\n\r\n    def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward\r\n        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward\r\n\r\n\r\nclass SdOptimizationSubQuad(SdOptimization):\r\n    name = \"sub-quadratic\"\r\n    cmd_opt = \"opt_sub_quad_attention\"\r\n\r\n    @property\r\n    def priority(self):\r\n        return 1000 if shared.device.type == 'mps' else 10\r\n\r\n    
def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward\r\n        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward\r\n\r\n\r\nclass SdOptimizationV1(SdOptimization):\r\n    name = \"V1\"\r\n    label = \"original v1\"\r\n    cmd_opt = \"opt_split_attention_v1\"\r\n    priority = 10\r\n\r\n    def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1\r\n        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1\r\n\r\n\r\nclass SdOptimizationInvokeAI(SdOptimization):\r\n    name = \"InvokeAI\"\r\n    cmd_opt = \"opt_split_attention_invokeai\"\r\n\r\n    @property\r\n    def priority(self):\r\n        return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10\r\n\r\n    def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI\r\n        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI\r\n\r\n\r\nclass SdOptimizationDoggettx(SdOptimization):\r\n    name = \"Doggettx\"\r\n    cmd_opt = \"opt_split_attention\"\r\n    priority = 90\r\n\r\n    def apply(self):\r\n        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward\r\n        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward\r\n        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward\r\n        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward\r\n\r\n\r\ndef list_optimizers(res):\r\n    res.extend([\r\n        SdOptimizationXformers(),\r\n        SdOptimizationSdpNoMem(),\r\n        SdOptimizationSdp(),\r\n        
SdOptimizationSubQuad(),\r\n        SdOptimizationV1(),\r\n        SdOptimizationInvokeAI(),\r\n        SdOptimizationDoggettx(),\r\n    ])\r\n\r\n\r\nif shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:\r\n    try:\r\n        import xformers.ops\r\n        shared.xformers_available = True\r\n    except Exception:\r\n        errors.report(\"Cannot import xformers\", exc_info=True)\r\n\r\n\r\ndef get_available_vram():\r\n    if shared.device.type == 'cuda':\r\n        stats = torch.cuda.memory_stats(shared.device)\r\n        mem_active = stats['active_bytes.all.current']\r\n        mem_reserved = stats['reserved_bytes.all.current']\r\n        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())\r\n        mem_free_torch = mem_reserved - mem_active\r\n        mem_free_total = mem_free_cuda + mem_free_torch\r\n        return mem_free_total\r\n    else:\r\n        return psutil.virtual_memory().available\r\n\r\n\r\n# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion\r\ndef split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):\r\n    h = self.heads\r\n\r\n    q_in = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k_in = self.to_k(context_k)\r\n    v_in = self.to_v(context_v)\r\n    del context, context_k, context_v, x\r\n\r\n    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))\r\n    del q_in, k_in, v_in\r\n\r\n    dtype = q.dtype\r\n    if shared.opts.upcast_attn:\r\n        q, k, v = q.float(), k.float(), v.float()\r\n\r\n    with devices.without_autocast(disable=not shared.opts.upcast_attn):\r\n        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)\r\n        for i in range(0, q.shape[0], 2):\r\n            end = i + 2\r\n            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])\r\n        
    s1 *= self.scale\r\n\r\n            s2 = s1.softmax(dim=-1)\r\n            del s1\r\n\r\n            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])\r\n            del s2\r\n        del q, k, v\r\n\r\n    r1 = r1.to(dtype)\r\n\r\n    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)\r\n    del r1\r\n\r\n    return self.to_out(r2)\r\n\r\n\r\n# taken from https://github.com/Doggettx/stable-diffusion and modified\r\ndef split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):\r\n    h = self.heads\r\n\r\n    q_in = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k_in = self.to_k(context_k)\r\n    v_in = self.to_v(context_v)\r\n\r\n    dtype = q_in.dtype\r\n    if shared.opts.upcast_attn:\r\n        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()\r\n\r\n    with devices.without_autocast(disable=not shared.opts.upcast_attn):\r\n        k_in = k_in * self.scale\r\n\r\n        del context, x\r\n\r\n        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))\r\n        del q_in, k_in, v_in\r\n\r\n        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)\r\n\r\n        mem_free_total = get_available_vram()\r\n\r\n        gb = 1024 ** 3\r\n        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()\r\n        modifier = 3 if q.element_size() == 2 else 2.5\r\n        mem_required = tensor_size * modifier\r\n        steps = 1\r\n\r\n        if mem_required > mem_free_total:\r\n            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))\r\n            # print(f\"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB \"\r\n            #       f\"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}\")\r\n\r\n        if steps > 
64:\r\n            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64\r\n            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '\r\n                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')\r\n\r\n        slice_size = q.shape[1] // steps\r\n        for i in range(0, q.shape[1], slice_size):\r\n            end = min(i + slice_size, q.shape[1])\r\n            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)\r\n\r\n            s2 = s1.softmax(dim=-1, dtype=q.dtype)\r\n            del s1\r\n\r\n            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)\r\n            del s2\r\n\r\n        del q, k, v\r\n\r\n    r1 = r1.to(dtype)\r\n\r\n    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)\r\n    del r1\r\n\r\n    return self.to_out(r2)\r\n\r\n\r\n# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --\r\nmem_total_gb = psutil.virtual_memory().total // (1 << 30)\r\n\r\n\r\ndef einsum_op_compvis(q, k, v):\r\n    s = einsum('b i d, b j d -> b i j', q, k)\r\n    s = s.softmax(dim=-1, dtype=s.dtype)\r\n    return einsum('b i j, b j d -> b i d', s, v)\r\n\r\n\r\ndef einsum_op_slice_0(q, k, v, slice_size):\r\n    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)\r\n    for i in range(0, q.shape[0], slice_size):\r\n        end = i + slice_size\r\n        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])\r\n    return r\r\n\r\n\r\ndef einsum_op_slice_1(q, k, v, slice_size):\r\n    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)\r\n    for i in range(0, q.shape[1], slice_size):\r\n        end = i + slice_size\r\n        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)\r\n    return r\r\n\r\n\r\ndef einsum_op_mps_v1(q, k, v):\r\n    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096\r\n        return einsum_op_compvis(q, 
k, v)\r\n    else:\r\n        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))\r\n        if slice_size % 4096 == 0:\r\n            slice_size -= 1\r\n        return einsum_op_slice_1(q, k, v, slice_size)\r\n\r\n\r\ndef einsum_op_mps_v2(q, k, v):\r\n    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:\r\n        return einsum_op_compvis(q, k, v)\r\n    else:\r\n        return einsum_op_slice_0(q, k, v, 1)\r\n\r\n\r\ndef einsum_op_tensor_mem(q, k, v, max_tensor_mb):\r\n    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)\r\n    if size_mb <= max_tensor_mb:\r\n        return einsum_op_compvis(q, k, v)\r\n    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()\r\n    if div <= q.shape[0]:\r\n        return einsum_op_slice_0(q, k, v, q.shape[0] // div)\r\n    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))\r\n\r\n\r\ndef einsum_op_cuda(q, k, v):\r\n    stats = torch.cuda.memory_stats(q.device)\r\n    mem_active = stats['active_bytes.all.current']\r\n    mem_reserved = stats['reserved_bytes.all.current']\r\n    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)\r\n    mem_free_torch = mem_reserved - mem_active\r\n    mem_free_total = mem_free_cuda + mem_free_torch\r\n    # Divide factor of safety as there's copying and fragmentation\r\n    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))\r\n\r\n\r\ndef einsum_op(q, k, v):\r\n    if q.device.type == 'cuda':\r\n        return einsum_op_cuda(q, k, v)\r\n\r\n    if q.device.type == 'mps':\r\n        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:\r\n            return einsum_op_mps_v1(q, k, v)\r\n        return einsum_op_mps_v2(q, k, v)\r\n\r\n    # Smaller slices are faster due to L2/L3/SLC caches.\r\n    # Tested on i7 with 8MB L3 cache.\r\n    return einsum_op_tensor_mem(q, k, v, 32)\r\n\r\n\r\ndef split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):\r\n    h = 
self.heads\r\n\r\n    q = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k = self.to_k(context_k)\r\n    v = self.to_v(context_v)\r\n    del context, context_k, context_v, x\r\n\r\n    dtype = q.dtype\r\n    if shared.opts.upcast_attn:\r\n        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()\r\n\r\n    with devices.without_autocast(disable=not shared.opts.upcast_attn):\r\n        k = k * self.scale\r\n\r\n        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))\r\n        r = einsum_op(q, k, v)\r\n    r = r.to(dtype)\r\n    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))\r\n\r\n# -- End of code from https://github.com/invoke-ai/InvokeAI --\r\n\r\n\r\n# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1\r\n# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface\r\ndef sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):\r\n    assert mask is None, \"attention-mask not currently implemented for SubQuadraticCrossAttnProcessor.\"\r\n\r\n    h = self.heads\r\n\r\n    q = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k = self.to_k(context_k)\r\n    v = self.to_v(context_v)\r\n    del context, context_k, context_v, x\r\n\r\n    q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)\r\n    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)\r\n    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)\r\n\r\n    if q.device.type == 'mps':\r\n        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()\r\n\r\n    dtype = q.dtype\r\n    if 
shared.opts.upcast_attn:\r\n        q, k = q.float(), k.float()\r\n\r\n    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)\r\n\r\n    x = x.to(dtype)\r\n\r\n    x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)\r\n\r\n    out_proj, dropout = self.to_out\r\n    x = out_proj(x)\r\n    x = dropout(x)\r\n\r\n    return x\r\n\r\n\r\ndef sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):\r\n    bytes_per_token = torch.finfo(q.dtype).bits//8\r\n    batch_x_heads, q_tokens, _ = q.shape\r\n    _, k_tokens, _ = k.shape\r\n    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens\r\n\r\n    if chunk_threshold is None:\r\n        if q.device.type == 'mps':\r\n            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)\r\n        else:\r\n            chunk_threshold_bytes = int(get_available_vram() * 0.7)\r\n    elif chunk_threshold == 0:\r\n        chunk_threshold_bytes = None\r\n    else:\r\n        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())\r\n\r\n    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:\r\n        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))\r\n    elif kv_chunk_size_min == 0:\r\n        kv_chunk_size_min = None\r\n\r\n    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:\r\n        # the big matmul fits into our memory limit; do everything in 1 chunk,\r\n        # i.e. 
send it down the unchunked fast-path\r\n        kv_chunk_size = k_tokens\r\n\r\n    with devices.without_autocast(disable=q.dtype == v.dtype):\r\n        return sub_quadratic_attention.efficient_dot_product_attention(\r\n            q,\r\n            k,\r\n            v,\r\n            query_chunk_size=q_chunk_size,\r\n            kv_chunk_size=kv_chunk_size,\r\n            kv_chunk_size_min = kv_chunk_size_min,\r\n            use_checkpoint=use_checkpoint,\r\n        )\r\n\r\n\r\ndef get_xformers_flash_attention_op(q, k, v):\r\n    if not shared.cmd_opts.xformers_flash_attention:\r\n        return None\r\n\r\n    try:\r\n        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp\r\n        fw, bw = flash_attention_op\r\n        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):\r\n            return flash_attention_op\r\n    except Exception as e:\r\n        errors.display_once(e, \"enabling flash attention\")\r\n\r\n    return None\r\n\r\n\r\ndef xformers_attention_forward(self, x, context=None, mask=None, **kwargs):\r\n    h = self.heads\r\n    q_in = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k_in = self.to_k(context_k)\r\n    v_in = self.to_v(context_v)\r\n\r\n    q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in))\r\n\r\n    del q_in, k_in, v_in\r\n\r\n    dtype = q.dtype\r\n    if shared.opts.upcast_attn:\r\n        q, k, v = q.float(), k.float(), v.float()\r\n\r\n    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))\r\n\r\n    out = out.to(dtype)\r\n\r\n    b, n, h, d = out.shape\r\n    out = out.reshape(b, n, h * d)\r\n    return self.to_out(out)\r\n\r\n\r\n# Based on Diffusers usage of scaled dot product attention from 
https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py\r\n# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface\r\ndef scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):\r\n    batch_size, sequence_length, inner_dim = x.shape\r\n\r\n    if mask is not None:\r\n        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)\r\n        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])\r\n\r\n    h = self.heads\r\n    q_in = self.to_q(x)\r\n    context = default(context, x)\r\n\r\n    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)\r\n    k_in = self.to_k(context_k)\r\n    v_in = self.to_v(context_v)\r\n\r\n    head_dim = inner_dim // h\r\n    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)\r\n    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)\r\n    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)\r\n\r\n    del q_in, k_in, v_in\r\n\r\n    dtype = q.dtype\r\n    if shared.opts.upcast_attn:\r\n        q, k, v = q.float(), k.float(), v.float()\r\n\r\n    # the output of sdp = (batch, num_heads, seq_len, head_dim)\r\n    hidden_states = torch.nn.functional.scaled_dot_product_attention(\r\n        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False\r\n    )\r\n\r\n    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)\r\n    hidden_states = hidden_states.to(dtype)\r\n\r\n    # linear proj\r\n    hidden_states = self.to_out[0](hidden_states)\r\n    # dropout\r\n    hidden_states = self.to_out[1](hidden_states)\r\n    return hidden_states\r\n\r\n\r\ndef scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):\r\n    with 
torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):\r\n        return scaled_dot_product_attention_forward(self, x, context, mask)\r\n\r\n\r\ndef cross_attention_attnblock_forward(self, x):\r\n        h_ = x\r\n        h_ = self.norm(h_)\r\n        q1 = self.q(h_)\r\n        k1 = self.k(h_)\r\n        v = self.v(h_)\r\n\r\n        # compute attention\r\n        b, c, h, w = q1.shape\r\n\r\n        q2 = q1.reshape(b, c, h*w)\r\n        del q1\r\n\r\n        q = q2.permute(0, 2, 1)   # b,hw,c\r\n        del q2\r\n\r\n        k = k1.reshape(b, c, h*w) # b,c,hw\r\n        del k1\r\n\r\n        h_ = torch.zeros_like(k, device=q.device)\r\n\r\n        mem_free_total = get_available_vram()\r\n\r\n        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()\r\n        mem_required = tensor_size * 2.5\r\n        steps = 1\r\n\r\n        if mem_required > mem_free_total:\r\n            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))\r\n\r\n        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]\r\n        for i in range(0, q.shape[1], slice_size):\r\n            end = i + slice_size\r\n\r\n            w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\r\n            w2 = w1 * (int(c)**(-0.5))\r\n            del w1\r\n            w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)\r\n            del w2\r\n\r\n            # attend to values\r\n            v1 = v.reshape(b, c, h*w)\r\n            w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)\r\n            del w3\r\n\r\n            h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\r\n            del v1, w4\r\n\r\n        h2 = h_.reshape(b, c, h, w)\r\n        del h_\r\n\r\n        h3 = self.proj_out(h2)\r\n        del h2\r\n\r\n        h3 += x\r\n\r\n        return h3\r\n\r\n\r\ndef xformers_attnblock_forward(self, 
x):\r\n    try:\r\n        h_ = x\r\n        h_ = self.norm(h_)\r\n        q = self.q(h_)\r\n        k = self.k(h_)\r\n        v = self.v(h_)\r\n        b, c, h, w = q.shape\r\n        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))\r\n        dtype = q.dtype\r\n        if shared.opts.upcast_attn:\r\n            q, k = q.float(), k.float()\r\n        q = q.contiguous()\r\n        k = k.contiguous()\r\n        v = v.contiguous()\r\n        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))\r\n        out = out.to(dtype)\r\n        out = rearrange(out, 'b (h w) c -> b c h w', h=h)\r\n        out = self.proj_out(out)\r\n        return x + out\r\n    except NotImplementedError:\r\n        return cross_attention_attnblock_forward(self, x)\r\n\r\n\r\ndef sdp_attnblock_forward(self, x):\r\n    h_ = x\r\n    h_ = self.norm(h_)\r\n    q = self.q(h_)\r\n    k = self.k(h_)\r\n    v = self.v(h_)\r\n    b, c, h, w = q.shape\r\n    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))\r\n    dtype = q.dtype\r\n    if shared.opts.upcast_attn:\r\n        q, k, v = q.float(), k.float(), v.float()\r\n    q = q.contiguous()\r\n    k = k.contiguous()\r\n    v = v.contiguous()\r\n    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)\r\n    out = out.to(dtype)\r\n    out = rearrange(out, 'b (h w) c -> b c h w', h=h)\r\n    out = self.proj_out(out)\r\n    return x + out\r\n\r\n\r\ndef sdp_no_mem_attnblock_forward(self, x):\r\n    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):\r\n        return sdp_attnblock_forward(self, x)\r\n\r\n\r\ndef sub_quad_attnblock_forward(self, x):\r\n    h_ = x\r\n    h_ = self.norm(h_)\r\n    q = self.q(h_)\r\n    k = self.k(h_)\r\n    v = self.v(h_)\r\n    b, c, h, w = q.shape\r\n    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))\r\n    q = 
q.contiguous()\r\n    k = k.contiguous()\r\n    v = v.contiguous()\r\n    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)\r\n    out = rearrange(out, 'b (h w) c -> b c h w', h=h)\r\n    out = self.proj_out(out)\r\n    return x + out\r\n"
  },
  {
    "path": "modules/sd_hijack_unet.py",
    "content": "import torch\r\nfrom packaging import version\r\nfrom einops import repeat\r\nimport math\r\n\r\nfrom modules import devices\r\nfrom modules.sd_hijack_utils import CondFunc\r\n\r\n\r\nclass TorchHijackForUnet:\r\n    \"\"\"\r\n    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;\r\n    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64\r\n    \"\"\"\r\n\r\n    def __getattr__(self, item):\r\n        if item == 'cat':\r\n            return self.cat\r\n\r\n        if hasattr(torch, item):\r\n            return getattr(torch, item)\r\n\r\n        raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{item}'\")\r\n\r\n    def cat(self, tensors, *args, **kwargs):\r\n        if len(tensors) == 2:\r\n            a, b = tensors\r\n            if a.shape[-2:] != b.shape[-2:]:\r\n                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode=\"nearest\")\r\n\r\n            tensors = (a, b)\r\n\r\n        return torch.cat(tensors, *args, **kwargs)\r\n\r\n\r\nth = TorchHijackForUnet()\r\n\r\n\r\n# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling\r\ndef apply_model(orig_func, self, x_noisy, t, cond, **kwargs):\r\n    \"\"\"Always make sure inputs to unet are in correct dtype.\"\"\"\r\n    if isinstance(cond, dict):\r\n        for y in cond.keys():\r\n            if isinstance(cond[y], list):\r\n                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]\r\n            else:\r\n                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]\r\n\r\n    with devices.autocast():\r\n        result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs)\r\n        if devices.unet_needs_upcast:\r\n            return result.float()\r\n        else:\r\n            return result\r\n\r\n\r\n# 
Monkey patch to create timestep embed tensor on device, avoiding a block.\r\ndef timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False):\r\n    \"\"\"\r\n    Create sinusoidal timestep embeddings.\r\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\r\n                      These may be fractional.\r\n    :param dim: the dimension of the output.\r\n    :param max_period: controls the minimum frequency of the embeddings.\r\n    :return: an [N x dim] Tensor of positional embeddings.\r\n    \"\"\"\r\n    if not repeat_only:\r\n        half = dim // 2\r\n        freqs = torch.exp(\r\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half\r\n        )\r\n        args = timesteps[:, None].float() * freqs[None]\r\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\r\n        if dim % 2:\r\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\r\n    else:\r\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\r\n    return embedding\r\n\r\n\r\n# Monkey patch to SpatialTransformer removing unnecessary contiguous calls.\r\n# Prevents a lot of unnecessary aten::copy_ calls\r\ndef spatial_transformer_forward(_, self, x: torch.Tensor, context=None):\r\n    # note: if no context is given, cross-attention defaults to self-attention\r\n    if not isinstance(context, list):\r\n        context = [context]\r\n    b, c, h, w = x.shape\r\n    x_in = x\r\n    x = self.norm(x)\r\n    if not self.use_linear:\r\n        x = self.proj_in(x)\r\n    x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)\r\n    if self.use_linear:\r\n        x = self.proj_in(x)\r\n    for i, block in enumerate(self.transformer_blocks):\r\n        x = block(x, context=context[i])\r\n    if self.use_linear:\r\n        x = self.proj_out(x)\r\n    x = x.view(b, h, w, c).permute(0, 3, 1, 2)\r\n    if not self.use_linear:\r\n        x = 
self.proj_out(x)\r\n    return x + x_in\r\n\r\n\r\nclass GELUHijack(torch.nn.GELU, torch.nn.Module):\r\n    def __init__(self, *args, **kwargs):\r\n        torch.nn.GELU.__init__(self, *args, **kwargs)\r\n    def forward(self, x):\r\n        if devices.unet_needs_upcast:\r\n            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)\r\n        else:\r\n            return torch.nn.GELU.forward(self, x)\r\n\r\n\r\nddpm_edit_hijack = None\r\ndef hijack_ddpm_edit():\r\n    global ddpm_edit_hijack\r\n    if not ddpm_edit_hijack:\r\n        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)\r\n        CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)\r\n        ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model)\r\n\r\n\r\nunet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast\r\nCondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)\r\nCondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding)\r\nCondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward)\r\nCondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)\r\n\r\nif version.parse(torch.__version__) <= version.parse(\"1.13.2\") or torch.cuda.is_available():\r\n    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)\r\n    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)\r\n  
  CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)\r\n\r\nfirst_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16\r\nfirst_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)\r\nCondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)\r\nCondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)\r\nCondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)\r\n\r\nCondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model)\r\nCondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model)\r\n\r\n\r\ndef timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs):\r\n    if devices.unet_needs_upcast and timesteps.dtype == torch.int64:\r\n        dtype = torch.float32\r\n    else:\r\n        dtype = devices.dtype_unet\r\n    return orig_func(timesteps, *args, **kwargs).to(dtype=dtype)\r\n\r\n\r\nCondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)\r\nCondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result)\r\n"
  },
  {
    "path": "modules/sd_hijack_utils.py",
    "content": "import importlib\r\n\r\n\r\nalways_true_func = lambda *args, **kwargs: True\r\n\r\n\r\nclass CondFunc:\r\n    def __new__(cls, orig_func, sub_func, cond_func=always_true_func):\r\n        self = super(CondFunc, cls).__new__(cls)\r\n        if isinstance(orig_func, str):\r\n            func_path = orig_func.split('.')\r\n            for i in range(len(func_path)-1, -1, -1):\r\n                try:\r\n                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))\r\n                    break\r\n                except ImportError:\r\n                    pass\r\n            try:\r\n                for attr_name in func_path[i:-1]:\r\n                    resolved_obj = getattr(resolved_obj, attr_name)\r\n                orig_func = getattr(resolved_obj, func_path[-1])\r\n                setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))\r\n            except AttributeError:\r\n                print(f\"Warning: Failed to resolve {orig_func} for CondFunc hijack\")\r\n                pass\r\n        self.__init__(orig_func, sub_func, cond_func)\r\n        return lambda *args, **kwargs: self(*args, **kwargs)\r\n    def __init__(self, orig_func, sub_func, cond_func):\r\n        self.__orig_func = orig_func\r\n        self.__sub_func = sub_func\r\n        self.__cond_func = cond_func\r\n    def __call__(self, *args, **kwargs):\r\n        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):\r\n            return self.__sub_func(self.__orig_func, *args, **kwargs)\r\n        else:\r\n            return self.__orig_func(*args, **kwargs)\r\n"
  },
  {
    "path": "modules/sd_hijack_xlmr.py",
    "content": "import torch\r\n\r\nfrom modules import sd_hijack_clip, devices\r\n\r\n\r\nclass FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):\r\n    def __init__(self, wrapped, hijack):\r\n        super().__init__(wrapped, hijack)\r\n\r\n        self.id_start = wrapped.config.bos_token_id\r\n        self.id_end = wrapped.config.eos_token_id\r\n        self.id_pad = wrapped.config.pad_token_id\r\n\r\n        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma\r\n\r\n    def encode_with_transformers(self, tokens):\r\n        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a\r\n        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer\r\n        # layer to work with - you have to use the last\r\n\r\n        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)\r\n        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)\r\n        z = features['projection_state']\r\n\r\n        return z\r\n\r\n    def encode_embedding_init_text(self, init_text, nvpt):\r\n        embedding_layer = self.wrapped.roberta.embeddings\r\n        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors=\"pt\", add_special_tokens=False)[\"input_ids\"]\r\n        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)\r\n\r\n        return embedded\r\n"
  },
  {
    "path": "modules/sd_models.py",
    "content": "import collections\r\nimport importlib\r\nimport os\r\nimport sys\r\nimport threading\r\nimport enum\r\n\r\nimport torch\r\nimport re\r\nimport safetensors.torch\r\nfrom omegaconf import OmegaConf, ListConfig\r\nfrom urllib import request\r\nimport ldm.modules.midas as midas\r\n\r\nfrom modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches\r\nfrom modules.timer import Timer\r\nfrom modules.shared import opts\r\nimport tomesd\r\nimport numpy as np\r\n\r\nmodel_dir = \"Stable-diffusion\"\r\nmodel_path = os.path.abspath(os.path.join(paths.models_path, model_dir))\r\n\r\ncheckpoints_list = {}\r\ncheckpoint_aliases = {}\r\ncheckpoint_alisases = checkpoint_aliases  # for compatibility with old name\r\ncheckpoints_loaded = collections.OrderedDict()\r\n\r\n\r\nclass ModelType(enum.Enum):\r\n    SD1 = 1\r\n    SD2 = 2\r\n    SDXL = 3\r\n    SSD = 4\r\n    SD3 = 5\r\n\r\n\r\ndef replace_key(d, key, new_key, value):\r\n    keys = list(d.keys())\r\n\r\n    d[new_key] = value\r\n\r\n    if key not in keys:\r\n        return d\r\n\r\n    index = keys.index(key)\r\n    keys[index] = new_key\r\n\r\n    new_d = {k: d[k] for k in keys}\r\n\r\n    d.clear()\r\n    d.update(new_d)\r\n    return d\r\n\r\n\r\nclass CheckpointInfo:\r\n    def __init__(self, filename):\r\n        self.filename = filename\r\n        abspath = os.path.abspath(filename)\r\n        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None\r\n\r\n        self.is_safetensors = os.path.splitext(filename)[1].lower() == \".safetensors\"\r\n\r\n        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):\r\n            name = abspath.replace(abs_ckpt_dir, '')\r\n        elif abspath.startswith(model_path):\r\n            name = abspath.replace(model_path, '')\r\n        else:\r\n    
        name = os.path.basename(filename)\r\n\r\n        if name.startswith(\"\\\\\") or name.startswith(\"/\"):\r\n            name = name[1:]\r\n\r\n        def read_metadata():\r\n            metadata = read_metadata_from_safetensors(filename)\r\n            self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)\r\n\r\n            return metadata\r\n\r\n        self.metadata = {}\r\n        if self.is_safetensors:\r\n            try:\r\n                self.metadata = cache.cached_data_for_file('safetensors-metadata', \"checkpoint/\" + name, filename, read_metadata)\r\n            except Exception as e:\r\n                errors.display(e, f\"reading metadata for {filename}\")\r\n\r\n        self.name = name\r\n        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]\r\n        self.model_name = os.path.splitext(name.replace(\"/\", \"_\").replace(\"\\\\\", \"_\"))[0]\r\n        self.hash = model_hash(filename)\r\n\r\n        self.sha256 = hashes.sha256_from_cache(self.filename, f\"checkpoint/{name}\")\r\n        self.shorthash = self.sha256[0:10] if self.sha256 else None\r\n\r\n        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'\r\n        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'\r\n\r\n        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']\r\n        if self.shorthash:\r\n            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']\r\n\r\n    def register(self):\r\n        checkpoints_list[self.title] = self\r\n        for id in self.ids:\r\n            checkpoint_aliases[id] = self\r\n\r\n    def calculate_shorthash(self):\r\n        self.sha256 = hashes.sha256(self.filename, f\"checkpoint/{self.name}\")\r\n        if self.sha256 is None:\r\n            return\r\n\r\n        shorthash = 
self.sha256[0:10]\r\n        if self.shorthash == self.sha256[0:10]:\r\n            return self.shorthash\r\n\r\n        self.shorthash = shorthash\r\n\r\n        if self.shorthash not in self.ids:\r\n            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']\r\n\r\n        old_title = self.title\r\n        self.title = f'{self.name} [{self.shorthash}]'\r\n        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'\r\n\r\n        replace_key(checkpoints_list, old_title, self.title, self)\r\n        self.register()\r\n\r\n        return self.shorthash\r\n\r\n\r\ntry:\r\n    # this silences the annoying \"Some weights of the model checkpoint were not used when initializing...\" message at start.\r\n    from transformers import logging, CLIPModel  # noqa: F401\r\n\r\n    logging.set_verbosity_error()\r\nexcept Exception:\r\n    pass\r\n\r\n\r\ndef setup_model():\r\n    \"\"\"called once at startup to do various one-time tasks related to SD models\"\"\"\r\n\r\n    os.makedirs(model_path, exist_ok=True)\r\n\r\n    enable_midas_autodownload()\r\n    patch_given_betas()\r\n\r\n\r\ndef checkpoint_tiles(use_short=False):\r\n    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]\r\n\r\n\r\ndef list_models():\r\n    checkpoints_list.clear()\r\n    checkpoint_aliases.clear()\r\n\r\n    cmd_ckpt = shared.cmd_opts.ckpt\r\n    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):\r\n        model_url = None\r\n        expected_sha256 = None\r\n    else:\r\n        model_url = f\"{shared.hf_endpoint}/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors\"\r\n        expected_sha256 = '6ce0161689b3853acaa03779ec93eafe75a02f4ced659bee03f50797806fa2fa'\r\n\r\n    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, 
ext_filter=[\".ckpt\", \".safetensors\"], download_name=\"v1-5-pruned-emaonly.safetensors\", ext_blacklist=[\".vae.ckpt\", \".vae.safetensors\"], hash_prefix=expected_sha256)\r\n\r\n    if os.path.exists(cmd_ckpt):\r\n        checkpoint_info = CheckpointInfo(cmd_ckpt)\r\n        checkpoint_info.register()\r\n\r\n        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title\r\n    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:\r\n        print(f\"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}\", file=sys.stderr)\r\n\r\n    for filename in model_list:\r\n        checkpoint_info = CheckpointInfo(filename)\r\n        checkpoint_info.register()\r\n\r\n\r\nre_strip_checksum = re.compile(r\"\\s*\\[[^]]+]\\s*$\")\r\n\r\n\r\ndef get_closet_checkpoint_match(search_string):\r\n    if not search_string:\r\n        return None\r\n\r\n    checkpoint_info = checkpoint_aliases.get(search_string, None)\r\n    if checkpoint_info is not None:\r\n        return checkpoint_info\r\n\r\n    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))\r\n    if found:\r\n        return found[0]\r\n\r\n    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)\r\n    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))\r\n    if found:\r\n        return found[0]\r\n\r\n    return None\r\n\r\n\r\ndef model_hash(filename):\r\n    \"\"\"old hash that only looks at a small part of the file and is prone to collisions\"\"\"\r\n\r\n    try:\r\n        with open(filename, \"rb\") as file:\r\n            import hashlib\r\n            m = hashlib.sha256()\r\n\r\n            file.seek(0x100000)\r\n            m.update(file.read(0x10000))\r\n            return m.hexdigest()[0:8]\r\n    except FileNotFoundError:\r\n        return 'NOFILE'\r\n\r\n\r\ndef 
select_checkpoint():\r\n    \"\"\"Raises `FileNotFoundError` if no checkpoints are found.\"\"\"\r\n    model_checkpoint = shared.opts.sd_model_checkpoint\r\n\r\n    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)\r\n    if checkpoint_info is not None:\r\n        return checkpoint_info\r\n\r\n    if len(checkpoints_list) == 0:\r\n        error_message = \"No checkpoints found. When searching for checkpoints, looked at:\"\r\n        if shared.cmd_opts.ckpt is not None:\r\n            error_message += f\"\\n - file {os.path.abspath(shared.cmd_opts.ckpt)}\"\r\n        error_message += f\"\\n - directory {model_path}\"\r\n        if shared.cmd_opts.ckpt_dir is not None:\r\n            error_message += f\"\\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}\"\r\n        error_message += \"Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations.\"\r\n        raise FileNotFoundError(error_message)\r\n\r\n    checkpoint_info = next(iter(checkpoints_list.values()))\r\n    if model_checkpoint is not None:\r\n        print(f\"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}\", file=sys.stderr)\r\n\r\n    return checkpoint_info\r\n\r\n\r\ncheckpoint_dict_replacements_sd1 = {\r\n    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',\r\n    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',\r\n    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',\r\n}\r\n\r\ncheckpoint_dict_replacements_sd2_turbo = { # Converts SD 2.1 Turbo from SGM to LDM format.\r\n    'conditioner.embedders.0.': 'cond_stage_model.',\r\n}\r\n\r\n\r\ndef transform_checkpoint_dict_key(k, replacements):\r\n    for text, replacement in replacements.items():\r\n        if k.startswith(text):\r\n            k = replacement + k[len(text):]\r\n\r\n    return 
k\r\n\r\n\r\ndef get_state_dict_from_checkpoint(pl_sd):\r\n    pl_sd = pl_sd.pop(\"state_dict\", pl_sd)\r\n    pl_sd.pop(\"state_dict\", None)\r\n\r\n    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024\r\n\r\n    sd = {}\r\n    for k, v in pl_sd.items():\r\n        if is_sd2_turbo:\r\n            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)\r\n        else:\r\n            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)\r\n\r\n        if new_key is not None:\r\n            sd[new_key] = v\r\n\r\n    pl_sd.clear()\r\n    pl_sd.update(sd)\r\n\r\n    return pl_sd\r\n\r\n\r\ndef read_metadata_from_safetensors(filename):\r\n    import json\r\n\r\n    with open(filename, mode=\"rb\") as file:\r\n        metadata_len = file.read(8)\r\n        metadata_len = int.from_bytes(metadata_len, \"little\")\r\n        json_start = file.read(2)\r\n\r\n        assert metadata_len > 2 and json_start in (b'{\"', b\"{'\"), f\"{filename} is not a safetensors file\"\r\n\r\n        res = {}\r\n\r\n        try:\r\n            json_data = json_start + file.read(metadata_len-2)\r\n            json_obj = json.loads(json_data)\r\n            for k, v in json_obj.get(\"__metadata__\", {}).items():\r\n                res[k] = v\r\n                if isinstance(v, str) and v[0:1] == '{':\r\n                    try:\r\n                        res[k] = json.loads(v)\r\n                    except Exception:\r\n                        pass\r\n        except Exception:\r\n             errors.report(f\"Error reading metadata from file: {filename}\", exc_info=True)\r\n\r\n        return res\r\n\r\n\r\ndef read_state_dict(checkpoint_file, print_global_state=False, map_location=None):\r\n    _, extension = os.path.splitext(checkpoint_file)\r\n    if extension.lower() == \".safetensors\":\r\n        device = map_location or 
shared.weight_load_location or devices.get_optimal_device_name()\r\n\r\n        if not shared.opts.disable_mmap_load_safetensors:\r\n            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)\r\n        else:\r\n            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())\r\n            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}\r\n    else:\r\n        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)\r\n\r\n    if print_global_state and \"global_step\" in pl_sd:\r\n        print(f\"Global Step: {pl_sd['global_step']}\")\r\n\r\n    sd = get_state_dict_from_checkpoint(pl_sd)\r\n    return sd\r\n\r\n\r\ndef get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):\r\n    sd_model_hash = checkpoint_info.calculate_shorthash()\r\n    timer.record(\"calculate hash\")\r\n\r\n    if checkpoint_info in checkpoints_loaded:\r\n        # use checkpoint cache\r\n        print(f\"Loading weights [{sd_model_hash}] from cache\")\r\n        # move to end as latest\r\n        checkpoints_loaded.move_to_end(checkpoint_info)\r\n        return checkpoints_loaded[checkpoint_info]\r\n\r\n    print(f\"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}\")\r\n    res = read_state_dict(checkpoint_info.filename)\r\n    timer.record(\"load weights from disk\")\r\n\r\n    return res\r\n\r\n\r\nclass SkipWritingToConfig:\r\n    \"\"\"This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.\"\"\"\r\n\r\n    skip = False\r\n    previous = None\r\n\r\n    def __enter__(self):\r\n        self.previous = SkipWritingToConfig.skip\r\n        SkipWritingToConfig.skip = True\r\n        return self\r\n\r\n    def __exit__(self, exc_type, exc_value, exc_traceback):\r\n        SkipWritingToConfig.skip = self.previous\r\n\r\n\r\ndef check_fp8(model):\r\n    if model is None:\r\n        return None\r\n    if 
devices.get_optimal_device_name() == \"mps\":\r\n        enable_fp8 = False\r\n    elif shared.opts.fp8_storage == \"Enable\":\r\n        enable_fp8 = True\r\n    elif getattr(model, \"is_sdxl\", False) and shared.opts.fp8_storage == \"Enable for SDXL\":\r\n        enable_fp8 = True\r\n    else:\r\n        enable_fp8 = False\r\n    return enable_fp8\r\n\r\n\r\ndef set_model_type(model, state_dict):\r\n    model.is_sd1 = False\r\n    model.is_sd2 = False\r\n    model.is_sdxl = False\r\n    model.is_ssd = False\r\n    model.is_sd3 = False\r\n\r\n    if \"model.diffusion_model.x_embedder.proj.weight\" in state_dict:\r\n        model.is_sd3 = True\r\n        model.model_type = ModelType.SD3\r\n    elif hasattr(model, 'conditioner'):\r\n        model.is_sdxl = True\r\n\r\n        if 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys():\r\n            model.is_ssd = True\r\n            model.model_type = ModelType.SSD\r\n        else:\r\n            model.model_type = ModelType.SDXL\r\n    elif hasattr(model.cond_stage_model, 'model'):\r\n        model.is_sd2 = True\r\n        model.model_type = ModelType.SD2\r\n    else:\r\n        model.is_sd1 = True\r\n        model.model_type = ModelType.SD1\r\n\r\n\r\ndef set_model_fields(model):\r\n    if not hasattr(model, 'latent_channels'):\r\n        model.latent_channels = 4\r\n\r\n\r\ndef load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):\r\n    sd_model_hash = checkpoint_info.calculate_shorthash()\r\n    timer.record(\"calculate hash\")\r\n\r\n    if devices.fp8:\r\n        # prevent model to load state dict in fp8\r\n        model.half()\r\n\r\n    if not SkipWritingToConfig.skip:\r\n        shared.opts.data[\"sd_model_checkpoint\"] = checkpoint_info.title\r\n\r\n    if state_dict is None:\r\n        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)\r\n\r\n    set_model_type(model, state_dict)\r\n    set_model_fields(model)\r\n\r\n    
if model.is_sdxl:\r\n        sd_models_xl.extend_sdxl(model)\r\n\r\n    if model.is_ssd:\r\n        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)\r\n\r\n    if shared.opts.sd_checkpoint_cache > 0:\r\n        # cache newly loaded model\r\n        checkpoints_loaded[checkpoint_info] = state_dict.copy()\r\n\r\n    if hasattr(model, \"before_load_weights\"):\r\n        model.before_load_weights(state_dict)\r\n\r\n    model.load_state_dict(state_dict, strict=False)\r\n    timer.record(\"apply weights to model\")\r\n\r\n    if hasattr(model, \"after_load_weights\"):\r\n        model.after_load_weights(state_dict)\r\n\r\n    del state_dict\r\n\r\n    # Set is_sdxl_inpaint flag.\r\n    # Checks Unet structure to detect inpaint model. The inpaint model's\r\n    # checkpoint state_dict does not contain the key\r\n    # 'diffusion_model.input_blocks.0.0.weight'.\r\n    diffusion_model_input = model.model.state_dict().get(\r\n        'diffusion_model.input_blocks.0.0.weight'\r\n    )\r\n    model.is_sdxl_inpaint = (\r\n        model.is_sdxl and\r\n        diffusion_model_input is not None and\r\n        diffusion_model_input.shape[1] == 9\r\n    )\r\n\r\n    if shared.cmd_opts.opt_channelslast:\r\n        model.to(memory_format=torch.channels_last)\r\n        timer.record(\"apply channels_last\")\r\n\r\n    if shared.cmd_opts.no_half:\r\n        model.float()\r\n        model.alphas_cumprod_original = model.alphas_cumprod\r\n        devices.dtype_unet = torch.float32\r\n        assert shared.cmd_opts.precision != \"half\", \"Cannot use --precision half with --no-half\"\r\n        timer.record(\"apply float()\")\r\n    else:\r\n        vae = model.first_stage_model\r\n        depth_model = getattr(model, 'depth_model', None)\r\n\r\n        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16\r\n        if shared.cmd_opts.no_half_vae:\r\n            model.first_stage_model = None\r\n        # with 
--upcast-sampling, don't convert the depth model weights to float16\r\n        if shared.cmd_opts.upcast_sampling and depth_model:\r\n            model.depth_model = None\r\n\r\n        alphas_cumprod = model.alphas_cumprod\r\n        model.alphas_cumprod = None\r\n        model.half()\r\n        model.alphas_cumprod = alphas_cumprod\r\n        model.alphas_cumprod_original = alphas_cumprod\r\n        model.first_stage_model = vae\r\n        if depth_model:\r\n            model.depth_model = depth_model\r\n\r\n        devices.dtype_unet = torch.float16\r\n        timer.record(\"apply half()\")\r\n\r\n    apply_alpha_schedule_override(model)\r\n\r\n    for module in model.modules():\r\n        if hasattr(module, 'fp16_weight'):\r\n            del module.fp16_weight\r\n        if hasattr(module, 'fp16_bias'):\r\n            del module.fp16_bias\r\n\r\n    if check_fp8(model):\r\n        devices.fp8 = True\r\n        first_stage = model.first_stage_model\r\n        model.first_stage_model = None\r\n        for module in model.modules():\r\n            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):\r\n                if shared.opts.cache_fp16_weight:\r\n                    module.fp16_weight = module.weight.data.clone().cpu().half()\r\n                    if module.bias is not None:\r\n                        module.fp16_bias = module.bias.data.clone().cpu().half()\r\n                module.to(torch.float8_e4m3fn)\r\n        model.first_stage_model = first_stage\r\n        timer.record(\"apply fp8\")\r\n    else:\r\n        devices.fp8 = False\r\n\r\n    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16\r\n\r\n    model.first_stage_model.to(devices.dtype_vae)\r\n    timer.record(\"apply dtype to VAE\")\r\n\r\n    # clean up cache if limit is reached\r\n    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:\r\n        
checkpoints_loaded.popitem(last=False)\r\n\r\n    model.sd_model_hash = sd_model_hash\r\n    model.sd_model_checkpoint = checkpoint_info.filename\r\n    model.sd_checkpoint_info = checkpoint_info\r\n    shared.opts.data[\"sd_checkpoint_hash\"] = checkpoint_info.sha256\r\n\r\n    if hasattr(model, 'logvar'):\r\n        model.logvar = model.logvar.to(devices.device)  # fix for training\r\n\r\n    sd_vae.delete_base_vae()\r\n    sd_vae.clear_loaded_vae()\r\n    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()\r\n    sd_vae.load_vae(model, vae_file, vae_source)\r\n    timer.record(\"load VAE\")\r\n\r\n\r\ndef enable_midas_autodownload():\r\n    \"\"\"\r\n    Gives the ldm.modules.midas.api.load_model function automatic downloading.\r\n\r\n    When the 512-depth-ema model, and other future models like it, is loaded,\r\n    it calls midas.api.load_model to load the associated midas depth model.\r\n    This function applies a wrapper to download the model to the correct\r\n    location automatically.\r\n    \"\"\"\r\n\r\n    midas_path = os.path.join(paths.models_path, 'midas')\r\n\r\n    # stable-diffusion-stability-ai hard-codes the midas model path to\r\n    # a location that differs from where other scripts using this model look.\r\n    # HACK: Overriding the path here.\r\n    for k, v in midas.api.ISL_PATHS.items():\r\n        file_name = os.path.basename(v)\r\n        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)\r\n\r\n    midas_urls = {\r\n        \"dpt_large\": \"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\",\r\n        \"dpt_hybrid\": \"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt\",\r\n        \"midas_v21\": \"https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt\",\r\n        \"midas_v21_small\": \"https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt\",\r\n    }\r\n\r\n    
midas.api.load_model_inner = midas.api.load_model\r\n\r\n    def load_model_wrapper(model_type):\r\n        path = midas.api.ISL_PATHS[model_type]\r\n        if not os.path.exists(path):\r\n            if not os.path.exists(midas_path):\r\n                os.mkdir(midas_path)\r\n\r\n            print(f\"Downloading midas model weights for {model_type} to {path}\")\r\n            request.urlretrieve(midas_urls[model_type], path)\r\n            print(f\"{model_type} downloaded\")\r\n\r\n        return midas.api.load_model_inner(model_type)\r\n\r\n    midas.api.load_model = load_model_wrapper\r\n\r\n\r\ndef patch_given_betas():\r\n    import ldm.models.diffusion.ddpm\r\n\r\n    def patched_register_schedule(*args, **kwargs):\r\n        \"\"\"a modified version of register_schedule function that converts plain list from Omegaconf into numpy\"\"\"\r\n\r\n        if isinstance(args[1], ListConfig):\r\n            args = (args[0], np.array(args[1]), *args[2:])\r\n\r\n        original_register_schedule(*args, **kwargs)\r\n\r\n    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)\r\n\r\n\r\ndef repair_config(sd_config, state_dict=None):\r\n    if not hasattr(sd_config.model.params, \"use_ema\"):\r\n        sd_config.model.params.use_ema = False\r\n\r\n    if hasattr(sd_config.model.params, 'unet_config'):\r\n        if shared.cmd_opts.no_half:\r\n            sd_config.model.params.unet_config.params.use_fp16 = False\r\n        elif shared.cmd_opts.upcast_sampling or shared.cmd_opts.precision == \"half\":\r\n            sd_config.model.params.unet_config.params.use_fp16 = True\r\n\r\n    if hasattr(sd_config.model.params, 'first_stage_config'):\r\n        if getattr(sd_config.model.params.first_stage_config.params.ddconfig, \"attn_type\", None) == \"vanilla-xformers\" and not shared.xformers_available:\r\n            sd_config.model.params.first_stage_config.params.ddconfig.attn_type = 
\"vanilla\"\r\n\r\n    # For UnCLIP-L, override the hardcoded karlo directory\r\n    if hasattr(sd_config.model.params, \"noise_aug_config\") and hasattr(sd_config.model.params.noise_aug_config.params, \"clip_stats_path\"):\r\n        karlo_path = os.path.join(paths.models_path, 'karlo')\r\n        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace(\"checkpoints/karlo_models\", karlo_path)\r\n\r\n    # Do not use checkpoint for inference.\r\n    # This helps prevent extra performance overhead on checking parameters.\r\n    # The perf overhead is about 100ms/it on 4090 for SDXL.\r\n    if hasattr(sd_config.model.params, \"network_config\"):\r\n        sd_config.model.params.network_config.params.use_checkpoint = False\r\n    if hasattr(sd_config.model.params, \"unet_config\"):\r\n        sd_config.model.params.unet_config.params.use_checkpoint = False\r\n\r\n\r\n\r\ndef rescale_zero_terminal_snr_abar(alphas_cumprod):\r\n    alphas_bar_sqrt = alphas_cumprod.sqrt()\r\n\r\n    # Store old values.\r\n    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()\r\n    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()\r\n\r\n    # Shift so the last timestep is zero.\r\n    alphas_bar_sqrt -= (alphas_bar_sqrt_T)\r\n\r\n    # Scale so the first timestep is back to the old value.\r\n    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)\r\n\r\n    # Convert alphas_bar_sqrt to betas\r\n    alphas_bar = alphas_bar_sqrt ** 2  # Revert sqrt\r\n    alphas_bar[-1] = 4.8973451890853435e-08\r\n    return alphas_bar\r\n\r\n\r\ndef apply_alpha_schedule_override(sd_model, p=None):\r\n    \"\"\"\r\n    Applies an override to the alpha schedule of the model according to settings.\r\n    - downcasts the alpha schedule to half precision\r\n    - rescales the alpha schedule to have zero terminal SNR\r\n    \"\"\"\r\n\r\n    if not hasattr(sd_model, 'alphas_cumprod') or not hasattr(sd_model, 
'alphas_cumprod_original'):\r\n        return\r\n\r\n    sd_model.alphas_cumprod = sd_model.alphas_cumprod_original.to(shared.device)\r\n\r\n    if opts.use_downcasted_alpha_bar:\r\n        if p is not None:\r\n            p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar\r\n        sd_model.alphas_cumprod = sd_model.alphas_cumprod.half().to(shared.device)\r\n\r\n    if opts.sd_noise_schedule == \"Zero Terminal SNR\":\r\n        if p is not None:\r\n            p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule\r\n        sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(sd_model.alphas_cumprod).to(shared.device)\r\n\r\n\r\nsd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'\r\nsd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'\r\nsdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'\r\nsdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'\r\n\r\n\r\nclass SdModelData:\r\n    def __init__(self):\r\n        self.sd_model = None\r\n        self.loaded_sd_models = []\r\n        self.was_loaded_at_least_once = False\r\n        self.lock = threading.Lock()\r\n\r\n    def get_sd_model(self):\r\n        if self.was_loaded_at_least_once:\r\n            return self.sd_model\r\n\r\n        if self.sd_model is None:\r\n            with self.lock:\r\n                if self.sd_model is not None or self.was_loaded_at_least_once:\r\n                    return self.sd_model\r\n\r\n                try:\r\n                    load_model()\r\n\r\n                except Exception as e:\r\n                    errors.display(e, \"loading stable diffusion model\", full_traceback=True)\r\n                    print(\"\", file=sys.stderr)\r\n                    print(\"Stable diffusion model failed to load\", file=sys.stderr)\r\n                    self.sd_model = None\r\n\r\n        return self.sd_model\r\n\r\n    def 
set_sd_model(self, v, already_loaded=False):\r\n        self.sd_model = v\r\n        if already_loaded:\r\n            sd_vae.base_vae = getattr(v, \"base_vae\", None)\r\n            sd_vae.loaded_vae_file = getattr(v, \"loaded_vae_file\", None)\r\n            sd_vae.checkpoint_info = v.sd_checkpoint_info\r\n\r\n        try:\r\n            self.loaded_sd_models.remove(v)\r\n        except ValueError:\r\n            pass\r\n\r\n        if v is not None:\r\n            self.loaded_sd_models.insert(0, v)\r\n\r\n\r\nmodel_data = SdModelData()\r\n\r\n\r\ndef get_empty_cond(sd_model):\r\n\r\n    p = processing.StableDiffusionProcessingTxt2Img()\r\n    extra_networks.activate(p, {})\r\n\r\n    if hasattr(sd_model, 'get_learned_conditioning'):\r\n        d = sd_model.get_learned_conditioning([\"\"])\r\n    else:\r\n        d = sd_model.cond_stage_model([\"\"])\r\n\r\n    if isinstance(d, dict):\r\n        d = d['crossattn']\r\n\r\n    return d\r\n\r\n\r\ndef send_model_to_cpu(m):\r\n    if m is not None:\r\n        if m.lowvram:\r\n            lowvram.send_everything_to_cpu()\r\n        else:\r\n            m.to(devices.cpu)\r\n\r\n    devices.torch_gc()\r\n\r\n\r\ndef model_target_device(m):\r\n    if lowvram.is_needed(m):\r\n        return devices.cpu\r\n    else:\r\n        return devices.device\r\n\r\n\r\ndef send_model_to_device(m):\r\n    lowvram.apply(m)\r\n\r\n    if not m.lowvram:\r\n        m.to(shared.device)\r\n\r\n\r\ndef send_model_to_trash(m):\r\n    m.to(device=\"meta\")\r\n    devices.torch_gc()\r\n\r\n\r\ndef instantiate_from_config(config, state_dict=None):\r\n    constructor = get_obj_from_str(config[\"target\"])\r\n\r\n    params = {**config.get(\"params\", {})}\r\n\r\n    if state_dict and \"state_dict\" in params and params[\"state_dict\"] is None:\r\n        params[\"state_dict\"] = state_dict\r\n\r\n    return constructor(**params)\r\n\r\n\r\ndef get_obj_from_str(string, reload=False):\r\n    module, cls = string.rsplit(\".\", 1)\r\n    if 
reload:\r\n        module_imp = importlib.import_module(module)\r\n        importlib.reload(module_imp)\r\n    return getattr(importlib.import_module(module, package=None), cls)\r\n\r\n\r\ndef load_model(checkpoint_info=None, already_loaded_state_dict=None):\r\n    from modules import sd_hijack\r\n    checkpoint_info = checkpoint_info or select_checkpoint()\r\n\r\n    timer = Timer()\r\n\r\n    if model_data.sd_model:\r\n        send_model_to_trash(model_data.sd_model)\r\n        model_data.sd_model = None\r\n        devices.torch_gc()\r\n\r\n    timer.record(\"unload existing model\")\r\n\r\n    if already_loaded_state_dict is not None:\r\n        state_dict = already_loaded_state_dict\r\n    else:\r\n        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)\r\n\r\n    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)\r\n    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)\r\n\r\n    timer.record(\"find config\")\r\n\r\n    sd_config = OmegaConf.load(checkpoint_config)\r\n    repair_config(sd_config, state_dict)\r\n\r\n    timer.record(\"load config\")\r\n\r\n    print(f\"Creating model from config: {checkpoint_config}\")\r\n\r\n    sd_model = None\r\n    try:\r\n        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):\r\n            with sd_disable_initialization.InitializeOnMeta():\r\n                sd_model = instantiate_from_config(sd_config.model, state_dict)\r\n\r\n    except Exception as e:\r\n        errors.display(e, \"creating model quickly\", full_traceback=True)\r\n\r\n    if sd_model is None:\r\n        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)\r\n\r\n        with sd_disable_initialization.InitializeOnMeta():\r\n            sd_model = instantiate_from_config(sd_config.model, 
state_dict)\r\n\r\n    sd_model.used_config = checkpoint_config\r\n\r\n    timer.record(\"create model\")\r\n\r\n    if shared.cmd_opts.no_half:\r\n        weight_dtype_conversion = None\r\n    else:\r\n        weight_dtype_conversion = {\r\n            'first_stage_model': None,\r\n            'alphas_cumprod': None,\r\n            '': torch.float16,\r\n        }\r\n\r\n    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):\r\n        load_model_weights(sd_model, checkpoint_info, state_dict, timer)\r\n\r\n    timer.record(\"load weights from state dict\")\r\n\r\n    send_model_to_device(sd_model)\r\n    timer.record(\"move model to device\")\r\n\r\n    sd_hijack.model_hijack.hijack(sd_model)\r\n\r\n    timer.record(\"hijack\")\r\n\r\n    sd_model.eval()\r\n    model_data.set_sd_model(sd_model)\r\n    model_data.was_loaded_at_least_once = True\r\n\r\n    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model\r\n\r\n    timer.record(\"load textual inversion embeddings\")\r\n\r\n    script_callbacks.model_loaded_callback(sd_model)\r\n\r\n    timer.record(\"scripts callbacks\")\r\n\r\n    with devices.autocast(), torch.no_grad():\r\n        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)\r\n\r\n    timer.record(\"calculate empty prompt\")\r\n\r\n    print(f\"Model loaded in {timer.summary()}.\")\r\n\r\n    return sd_model\r\n\r\n\r\ndef reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):\r\n    \"\"\"\r\n    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.\r\n    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).\r\n    If not, returns the model that can be used to load weights from checkpoint_info's 
file.\r\n    If no such model exists, returns None.\r\n    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).\r\n    \"\"\"\r\n\r\n    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:\r\n        return sd_model\r\n\r\n    if shared.opts.sd_checkpoints_keep_in_cpu:\r\n        send_model_to_cpu(sd_model)\r\n        timer.record(\"send model to cpu\")\r\n\r\n    already_loaded = None\r\n    for i in reversed(range(len(model_data.loaded_sd_models))):\r\n        loaded_model = model_data.loaded_sd_models[i]\r\n        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:\r\n            already_loaded = loaded_model\r\n            continue\r\n\r\n        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:\r\n            print(f\"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}\")\r\n            del model_data.loaded_sd_models[i]\r\n            send_model_to_trash(loaded_model)\r\n            timer.record(\"send model to trash\")\r\n\r\n    if already_loaded is not None:\r\n        send_model_to_device(already_loaded)\r\n        timer.record(\"send model to device\")\r\n\r\n        model_data.set_sd_model(already_loaded, already_loaded=True)\r\n\r\n        if not SkipWritingToConfig.skip:\r\n            shared.opts.data[\"sd_model_checkpoint\"] = already_loaded.sd_checkpoint_info.title\r\n            shared.opts.data[\"sd_checkpoint_hash\"] = already_loaded.sd_checkpoint_info.sha256\r\n\r\n        print(f\"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}\")\r\n        sd_vae.reload_vae_weights(already_loaded)\r\n        return model_data.sd_model\r\n    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:\r\n        print(f\"Loading 
model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})\")\r\n\r\n        model_data.sd_model = None\r\n        load_model(checkpoint_info)\r\n        return model_data.sd_model\r\n    elif len(model_data.loaded_sd_models) > 0:\r\n        sd_model = model_data.loaded_sd_models.pop()\r\n        model_data.sd_model = sd_model\r\n\r\n        sd_vae.base_vae = getattr(sd_model, \"base_vae\", None)\r\n        sd_vae.loaded_vae_file = getattr(sd_model, \"loaded_vae_file\", None)\r\n        sd_vae.checkpoint_info = sd_model.sd_checkpoint_info\r\n\r\n        print(f\"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}\")\r\n        return sd_model\r\n    else:\r\n        return None\r\n\r\n\r\ndef reload_model_weights(sd_model=None, info=None, forced_reload=False):\r\n    checkpoint_info = info or select_checkpoint()\r\n\r\n    timer = Timer()\r\n\r\n    if not sd_model:\r\n        sd_model = model_data.sd_model\r\n\r\n    if sd_model is None:  # previous model load failed\r\n        current_checkpoint_info = None\r\n    else:\r\n        current_checkpoint_info = sd_model.sd_checkpoint_info\r\n        if check_fp8(sd_model) != devices.fp8:\r\n            # load from state dict again to prevent extra numerical errors\r\n            forced_reload = True\r\n        elif sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload:\r\n            return sd_model\r\n\r\n    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)\r\n    if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:\r\n        return sd_model\r\n\r\n    if sd_model is not None:\r\n        sd_unet.apply_unet(\"None\")\r\n        send_model_to_cpu(sd_model)\r\n        sd_hijack.model_hijack.undo_hijack(sd_model)\r\n\r\n    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)\r\n\r\n    checkpoint_config = 
sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)\r\n\r\n    timer.record(\"find config\")\r\n\r\n    if sd_model is None or checkpoint_config != sd_model.used_config:\r\n        if sd_model is not None:\r\n            send_model_to_trash(sd_model)\r\n\r\n        load_model(checkpoint_info, already_loaded_state_dict=state_dict)\r\n        return model_data.sd_model\r\n\r\n    try:\r\n        load_model_weights(sd_model, checkpoint_info, state_dict, timer)\r\n    except Exception:\r\n        print(\"Failed to load checkpoint, restoring previous\")\r\n        load_model_weights(sd_model, current_checkpoint_info, None, timer)\r\n        raise\r\n    finally:\r\n        sd_hijack.model_hijack.hijack(sd_model)\r\n        timer.record(\"hijack\")\r\n\r\n        if not sd_model.lowvram:\r\n            sd_model.to(devices.device)\r\n            timer.record(\"move model to device\")\r\n\r\n        script_callbacks.model_loaded_callback(sd_model)\r\n        timer.record(\"script callbacks\")\r\n\r\n    print(f\"Weights loaded in {timer.summary()}.\")\r\n\r\n    model_data.set_sd_model(sd_model)\r\n    sd_unet.apply_unet()\r\n\r\n    return sd_model\r\n\r\n\r\ndef unload_model_weights(sd_model=None, info=None):\r\n    send_model_to_cpu(sd_model or shared.sd_model)\r\n\r\n    return sd_model\r\n\r\n\r\ndef apply_token_merging(sd_model, token_merging_ratio):\r\n    \"\"\"\r\n    Applies speed and memory optimizations from tomesd.\r\n    \"\"\"\r\n\r\n    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)\r\n\r\n    if current_token_merging_ratio == token_merging_ratio:\r\n        return\r\n\r\n    if current_token_merging_ratio > 0:\r\n        tomesd.remove_patch(sd_model)\r\n\r\n    if token_merging_ratio > 0:\r\n        tomesd.apply_patch(\r\n            sd_model,\r\n            ratio=token_merging_ratio,\r\n            use_rand=False,  # can cause issues with some samplers\r\n            merge_attn=True,\r\n            
merge_crossattn=False,\r\n            merge_mlp=False\r\n        )\r\n\r\n    sd_model.applied_token_merged_ratio = token_merging_ratio\r\n"
  },
  {
    "path": "modules/sd_models_config.py",
    "content": "import os\r\n\r\nimport torch\r\n\r\nfrom modules import shared, paths, sd_disable_initialization, devices\r\n\r\nsd_configs_path = shared.sd_configs_path\r\nsd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], \"configs\", \"stable-diffusion\")\r\nsd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], \"configs\", \"inference\")\r\n\r\n\r\nconfig_default = shared.sd_default_config\r\nconfig_sd2 = os.path.join(sd_repo_configs_path, \"v2-inference.yaml\")\r\nconfig_sd2v = os.path.join(sd_repo_configs_path, \"v2-inference-v.yaml\")\r\nconfig_sd2_inpainting = os.path.join(sd_repo_configs_path, \"v2-inpainting-inference.yaml\")\r\nconfig_sdxl = os.path.join(sd_xl_repo_configs_path, \"sd_xl_base.yaml\")\r\nconfig_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, \"sd_xl_refiner.yaml\")\r\nconfig_sdxl_inpainting = os.path.join(sd_configs_path, \"sd_xl_inpaint.yaml\")\r\nconfig_depth_model = os.path.join(sd_repo_configs_path, \"v2-midas-inference.yaml\")\r\nconfig_unclip = os.path.join(sd_repo_configs_path, \"v2-1-stable-unclip-l-inference.yaml\")\r\nconfig_unopenclip = os.path.join(sd_repo_configs_path, \"v2-1-stable-unclip-h-inference.yaml\")\r\nconfig_inpainting = os.path.join(sd_configs_path, \"v1-inpainting-inference.yaml\")\r\nconfig_instruct_pix2pix = os.path.join(sd_configs_path, \"instruct-pix2pix.yaml\")\r\nconfig_alt_diffusion = os.path.join(sd_configs_path, \"alt-diffusion-inference.yaml\")\r\nconfig_alt_diffusion_m18 = os.path.join(sd_configs_path, \"alt-diffusion-m18-inference.yaml\")\r\nconfig_sd3 = os.path.join(sd_configs_path, \"sd3-inference.yaml\")\r\n\r\n\r\ndef is_using_v_parameterization_for_sd2(state_dict):\r\n    \"\"\"\r\n    Detects whether unet in state_dict is using v-parameterization. Returns True if it is. 
You're welcome.\r\n    \"\"\"\r\n\r\n    import ldm.modules.diffusionmodules.openaimodel\r\n\r\n    device = devices.device\r\n\r\n    with sd_disable_initialization.DisableInitialization():\r\n        unet = ldm.modules.diffusionmodules.openaimodel.UNetModel(\r\n            use_checkpoint=False,\r\n            use_fp16=False,\r\n            image_size=32,\r\n            in_channels=4,\r\n            out_channels=4,\r\n            model_channels=320,\r\n            attention_resolutions=[4, 2, 1],\r\n            num_res_blocks=2,\r\n            channel_mult=[1, 2, 4, 4],\r\n            num_head_channels=64,\r\n            use_spatial_transformer=True,\r\n            use_linear_in_transformer=True,\r\n            transformer_depth=1,\r\n            context_dim=1024,\r\n            legacy=False\r\n        )\r\n        unet.eval()\r\n\r\n    with torch.no_grad():\r\n        unet_sd = {k.replace(\"model.diffusion_model.\", \"\"): v for k, v in state_dict.items() if \"model.diffusion_model.\" in k}\r\n        unet.load_state_dict(unet_sd, strict=True)\r\n        unet.to(device=device, dtype=devices.dtype_unet)\r\n\r\n        test_cond = torch.ones((1, 2, 1024), device=device) * 0.5\r\n        x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5\r\n\r\n        with devices.autocast():\r\n            out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item()\r\n\r\n    return out < -1\r\n\r\n\r\ndef guess_model_config_from_state_dict(sd, filename):\r\n    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)\r\n    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)\r\n    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)\r\n\r\n    if \"model.diffusion_model.x_embedder.proj.weight\" in sd:\r\n        return config_sd3\r\n\r\n    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:\r\n        
if diffusion_model_input.shape[1] == 9:\r\n            return config_sdxl_inpainting\r\n        else:\r\n            return config_sdxl\r\n\r\n    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:\r\n        return config_sdxl_refiner\r\n    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:\r\n        return config_depth_model\r\n    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:\r\n        return config_unclip\r\n    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024:\r\n        return config_unopenclip\r\n\r\n    if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024:\r\n        if diffusion_model_input.shape[1] == 9:\r\n            return config_sd2_inpainting\r\n        elif is_using_v_parameterization_for_sd2(sd):\r\n            return config_sd2v\r\n        else:\r\n            return config_sd2\r\n\r\n    if diffusion_model_input is not None:\r\n        if diffusion_model_input.shape[1] == 9:\r\n            return config_inpainting\r\n        if diffusion_model_input.shape[1] == 8:\r\n            return config_instruct_pix2pix\r\n\r\n    if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None:\r\n        if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024:\r\n            return config_alt_diffusion_m18\r\n        return config_alt_diffusion\r\n\r\n    return config_default\r\n\r\n\r\ndef find_checkpoint_config(state_dict, info):\r\n    if info is None:\r\n        return guess_model_config_from_state_dict(state_dict, \"\")\r\n\r\n    config = find_checkpoint_config_near_filename(info)\r\n    if config is not None:\r\n        return config\r\n\r\n    return guess_model_config_from_state_dict(state_dict, info.filename)\r\n\r\n\r\ndef find_checkpoint_config_near_filename(info):\r\n    if info is None:\r\n        return None\r\n\r\n    config = 
f\"{os.path.splitext(info.filename)[0]}.yaml\"\r\n    if os.path.exists(config):\r\n        return config\r\n\r\n    return None\r\n\r\n"
  },
  {
    "path": "modules/sd_models_types.py",
    "content": "from ldm.models.diffusion.ddpm import LatentDiffusion\r\nfrom typing import TYPE_CHECKING\r\n\r\n\r\nif TYPE_CHECKING:\r\n    from modules.sd_models import CheckpointInfo\r\n\r\n\r\nclass WebuiSdModel(LatentDiffusion):\r\n    \"\"\"This class is not actually instantinated, but its fields are created and fieeld by webui\"\"\"\r\n\r\n    lowvram: bool\r\n    \"\"\"True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info\"\"\"\r\n\r\n    sd_model_hash: str\r\n    \"\"\"short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used\"\"\"\r\n\r\n    sd_model_checkpoint: str\r\n    \"\"\"path to the file on disk that model weights were obtained from\"\"\"\r\n\r\n    sd_checkpoint_info: 'CheckpointInfo'\r\n    \"\"\"structure with additional information about the file with model's weights\"\"\"\r\n\r\n    is_sdxl: bool\r\n    \"\"\"True if the model's architecture is SDXL or SSD\"\"\"\r\n\r\n    is_ssd: bool\r\n    \"\"\"True if the model is SSD\"\"\"\r\n\r\n    is_sd2: bool\r\n    \"\"\"True if the model's architecture is SD 2.x\"\"\"\r\n\r\n    is_sd1: bool\r\n    \"\"\"True if the model's architecture is SD 1.x\"\"\"\r\n\r\n    is_sd3: bool\r\n    \"\"\"True if the model's architecture is SD 3\"\"\"\r\n\r\n    latent_channels: int\r\n    \"\"\"number of layer in latent image representation; will be 16 in SD3 and 4 in other version\"\"\"\r\n"
  },
  {
    "path": "modules/sd_models_xl.py",
    "content": "from __future__ import annotations\r\n\r\nimport torch\r\n\r\nimport sgm.models.diffusion\r\nimport sgm.modules.diffusionmodules.denoiser_scaling\r\nimport sgm.modules.diffusionmodules.discretizer\r\nfrom modules import devices, shared, prompt_parser\r\nfrom modules import torch_utils\r\n\r\n\r\ndef get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):\r\n    for embedder in self.conditioner.embedders:\r\n        embedder.ucg_rate = 0.0\r\n\r\n    width = getattr(batch, 'width', 1024) or 1024\r\n    height = getattr(batch, 'height', 1024) or 1024\r\n    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)\r\n    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score\r\n\r\n    devices_args = dict(device=devices.device, dtype=devices.dtype)\r\n\r\n    sdxl_conds = {\r\n        \"txt\": batch,\r\n        \"original_size_as_tuple\": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),\r\n        \"crop_coords_top_left\": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),\r\n        \"target_size_as_tuple\": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),\r\n        \"aesthetic_score\": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),\r\n    }\r\n\r\n    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)\r\n    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])\r\n\r\n    return c\r\n\r\n\r\ndef apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):\r\n    \"\"\"WARNING: This function is called once per denoising iteration. DO NOT add\r\n    expensive functionc calls such as `model.state_dict`. 
\"\"\"\r\n    if self.is_sdxl_inpaint:\r\n        x = torch.cat([x] + cond['c_concat'], dim=1)\r\n\r\n    return self.model(x, t, cond)\r\n\r\n\r\ndef get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility\r\n    return x\r\n\r\n\r\nsgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning\r\nsgm.models.diffusion.DiffusionEngine.apply_model = apply_model\r\nsgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding\r\n\r\n\r\ndef encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):\r\n    res = []\r\n\r\n    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:\r\n        encoded = embedder.encode_embedding_init_text(init_text, nvpt)\r\n        res.append(encoded)\r\n\r\n    return torch.cat(res, dim=1)\r\n\r\n\r\ndef tokenize(self: sgm.modules.GeneralConditioner, texts):\r\n    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:\r\n        return embedder.tokenize(texts)\r\n\r\n    raise AssertionError('no tokenizer available')\r\n\r\n\r\n\r\ndef process_texts(self, texts):\r\n    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:\r\n        return embedder.process_texts(texts)\r\n\r\n\r\ndef get_target_prompt_token_count(self, token_count):\r\n    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:\r\n        return embedder.get_target_prompt_token_count(token_count)\r\n\r\n\r\n# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist\r\nsgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text\r\nsgm.modules.GeneralConditioner.tokenize = tokenize\r\nsgm.modules.GeneralConditioner.process_texts = 
process_texts\r\nsgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count\r\n\r\n\r\ndef extend_sdxl(model):\r\n    \"\"\"this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.\"\"\"\r\n\r\n    dtype = torch_utils.get_param(model.model.diffusion_model).dtype\r\n    model.model.diffusion_model.dtype = dtype\r\n    model.model.conditioning_key = 'crossattn'\r\n    model.cond_stage_key = 'txt'\r\n    # model.cond_stage_model will be set in sd_hijack\r\n\r\n    model.parameterization = \"v\" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else \"eps\"\r\n\r\n    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()\r\n    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32)\r\n\r\n    model.conditioner.wrapped = torch.nn.Module()\r\n\r\n\r\nsgm.modules.attention.print = shared.ldm_print\r\nsgm.modules.diffusionmodules.model.print = shared.ldm_print\r\nsgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print\r\nsgm.modules.encoders.modules.print = shared.ldm_print\r\n\r\n# this gets the code to load the vanilla attention that we override\r\nsgm.modules.attention.SDP_IS_AVAILABLE = True\r\nsgm.modules.attention.XFORMERS_IS_AVAILABLE = False\r\n"
  },
  {
    "path": "modules/sd_samplers.py",
    "content": "from __future__ import annotations\r\n\r\nimport functools\r\nimport logging\r\nfrom modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers\r\n\r\n# imports for functions that previously were here and are used by other modules\r\nsamples_to_image_grid = sd_samplers_common.samples_to_image_grid\r\nsample_to_image = sd_samplers_common.sample_to_image\r\n\r\nall_samplers = [\r\n    *sd_samplers_kdiffusion.samplers_data_k_diffusion,\r\n    *sd_samplers_timesteps.samplers_data_timesteps,\r\n    *sd_samplers_lcm.samplers_data_lcm,\r\n]\r\nall_samplers_map = {x.name: x for x in all_samplers}\r\n\r\nsamplers: list[sd_samplers_common.SamplerData] = []\r\nsamplers_for_img2img: list[sd_samplers_common.SamplerData] = []\r\nsamplers_map = {}\r\nsamplers_hidden = {}\r\n\r\n\r\ndef find_sampler_config(name):\r\n    if name is not None:\r\n        config = all_samplers_map.get(name, None)\r\n    else:\r\n        config = all_samplers[0]\r\n\r\n    return config\r\n\r\n\r\ndef create_sampler(name, model):\r\n    config = find_sampler_config(name)\r\n\r\n    assert config is not None, f'bad sampler name: {name}'\r\n\r\n    if model.is_sdxl and config.options.get(\"no_sdxl\", False):\r\n        raise Exception(f\"Sampler {config.name} is not supported for SDXL\")\r\n\r\n    sampler = config.constructor(model)\r\n    sampler.config = config\r\n\r\n    return sampler\r\n\r\n\r\ndef set_samplers():\r\n    global samplers, samplers_for_img2img, samplers_hidden\r\n\r\n    samplers_hidden = set(shared.opts.hide_samplers)\r\n    samplers = all_samplers\r\n    samplers_for_img2img = all_samplers\r\n\r\n    samplers_map.clear()\r\n    for sampler in all_samplers:\r\n        samplers_map[sampler.name.lower()] = sampler.name\r\n        for alias in sampler.aliases:\r\n            samplers_map[alias.lower()] = sampler.name\r\n\r\n\r\ndef visible_sampler_names():\r\n    return [x.name for x in samplers if x.name not 
in samplers_hidden]\r\n\r\n\r\ndef visible_samplers():\r\n    return [x for x in samplers if x.name not in samplers_hidden]\r\n\r\n\r\ndef get_sampler_from_infotext(d: dict):\r\n    return get_sampler_and_scheduler(d.get(\"Sampler\"), d.get(\"Schedule type\"))[0]\r\n\r\n\r\ndef get_scheduler_from_infotext(d: dict):\r\n    return get_sampler_and_scheduler(d.get(\"Sampler\"), d.get(\"Schedule type\"))[1]\r\n\r\n\r\ndef get_hr_sampler_and_scheduler(d: dict):\r\n    hr_sampler = d.get(\"Hires sampler\", \"Use same sampler\")\r\n    sampler = d.get(\"Sampler\") if hr_sampler == \"Use same sampler\" else hr_sampler\r\n\r\n    hr_scheduler = d.get(\"Hires schedule type\", \"Use same scheduler\")\r\n    scheduler = d.get(\"Schedule type\") if hr_scheduler == \"Use same scheduler\" else hr_scheduler\r\n\r\n    sampler, scheduler = get_sampler_and_scheduler(sampler, scheduler)\r\n\r\n    sampler = sampler if sampler != d.get(\"Sampler\") else \"Use same sampler\"\r\n    scheduler = scheduler if scheduler != d.get(\"Schedule type\") else \"Use same scheduler\"\r\n\r\n    return sampler, scheduler\r\n\r\n\r\ndef get_hr_sampler_from_infotext(d: dict):\r\n    return get_hr_sampler_and_scheduler(d)[0]\r\n\r\n\r\ndef get_hr_scheduler_from_infotext(d: dict):\r\n    return get_hr_sampler_and_scheduler(d)[1]\r\n\r\n\r\n@functools.cache\r\ndef get_sampler_and_scheduler(sampler_name, scheduler_name, *, convert_automatic=True):\r\n    default_sampler = samplers[0]\r\n    found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])\r\n\r\n    name = sampler_name or default_sampler.name\r\n\r\n    for scheduler in sd_schedulers.schedulers:\r\n        name_options = [scheduler.label, scheduler.name, *(scheduler.aliases or [])]\r\n\r\n        for name_option in name_options:\r\n            if name.endswith(\" \" + name_option):\r\n                found_scheduler = scheduler\r\n                name = name[0:-(len(name_option) + 1)]\r\n                
break\r\n\r\n    sampler = all_samplers_map.get(name, default_sampler)\r\n\r\n    # revert back to Automatic if it's the default scheduler for the selected sampler\r\n    if convert_automatic and sampler.options.get('scheduler', None) == found_scheduler.name:\r\n        found_scheduler = sd_schedulers.schedulers[0]\r\n\r\n    return sampler.name, found_scheduler.label\r\n\r\n\r\ndef fix_p_invalid_sampler_and_scheduler(p):\r\n    i_sampler_name, i_scheduler = p.sampler_name, p.scheduler\r\n    p.sampler_name, p.scheduler = get_sampler_and_scheduler(p.sampler_name, p.scheduler, convert_automatic=False)\r\n    if p.sampler_name != i_sampler_name or i_scheduler != p.scheduler:\r\n        logging.warning(f'Sampler Scheduler autocorrection: \"{i_sampler_name}\" -> \"{p.sampler_name}\", \"{i_scheduler}\" -> \"{p.scheduler}\"')\r\n\r\n\r\nset_samplers()\r\n"
  },
  {
    "path": "modules/sd_samplers_cfg_denoiser.py",
    "content": "import torch\r\nfrom modules import prompt_parser, sd_samplers_common\r\n\r\nfrom modules.shared import opts, state\r\nimport modules.shared as shared\r\nfrom modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback\r\nfrom modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback\r\nfrom modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback\r\n\r\n\r\ndef catenate_conds(conds):\r\n    if not isinstance(conds[0], dict):\r\n        return torch.cat(conds)\r\n\r\n    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}\r\n\r\n\r\ndef subscript_cond(cond, a, b):\r\n    if not isinstance(cond, dict):\r\n        return cond[a:b]\r\n\r\n    return {key: vec[a:b] for key, vec in cond.items()}\r\n\r\n\r\ndef pad_cond(tensor, repeats, empty):\r\n    if not isinstance(tensor, dict):\r\n        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)\r\n\r\n    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)\r\n    return tensor\r\n\r\n\r\nclass CFGDenoiser(torch.nn.Module):\r\n    \"\"\"\r\n    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)\r\n    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)\r\n    instead of one. 
Originally, the second prompt is just an empty string, but we use non-empty\r\n    negative prompt.\r\n    \"\"\"\r\n\r\n    def __init__(self, sampler):\r\n        super().__init__()\r\n        self.model_wrap = None\r\n        self.mask = None\r\n        self.nmask = None\r\n        self.init_latent = None\r\n        self.steps = None\r\n        \"\"\"number of steps as specified by user in UI\"\"\"\r\n\r\n        self.total_steps = None\r\n        \"\"\"expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler\"\"\"\r\n\r\n        self.step = 0\r\n        self.image_cfg_scale = None\r\n        self.padded_cond_uncond = False\r\n        self.padded_cond_uncond_v0 = False\r\n        self.sampler = sampler\r\n        self.model_wrap = None\r\n        self.p = None\r\n\r\n        self.cond_scale_miltiplier = 1.0\r\n\r\n        self.need_last_noise_uncond = False\r\n        self.last_noise_uncond = None\r\n\r\n        # NOTE: masking before denoising can cause the original latents to be oversmoothed\r\n        # as the original latents do not have noise\r\n        self.mask_before_denoising = False\r\n\r\n    @property\r\n    def inner_model(self):\r\n        raise NotImplementedError()\r\n\r\n    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):\r\n        denoised_uncond = x_out[-uncond.shape[0]:]\r\n        denoised = torch.clone(denoised_uncond)\r\n\r\n        for i, conds in enumerate(conds_list):\r\n            for cond_index, weight in conds:\r\n                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)\r\n\r\n        return denoised\r\n\r\n    def combine_denoised_for_edit_model(self, x_out, cond_scale):\r\n        out_cond, out_img_cond, out_uncond = x_out.chunk(3)\r\n        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)\r\n\r\n        return denoised\r\n\r\n    def get_pred_x0(self, x_in, 
x_out, sigma):\r\n        return x_out\r\n\r\n    def update_inner_model(self):\r\n        self.model_wrap = None\r\n\r\n        c, uc = self.p.get_conds()\r\n        self.sampler.sampler_extra_args['cond'] = c\r\n        self.sampler.sampler_extra_args['uncond'] = uc\r\n\r\n    def pad_cond_uncond(self, cond, uncond):\r\n        empty = shared.sd_model.cond_stage_model_empty_prompt\r\n        num_repeats = (cond.shape[1] - uncond.shape[1]) // empty.shape[1]\r\n\r\n        if num_repeats < 0:\r\n            cond = pad_cond(cond, -num_repeats, empty)\r\n            self.padded_cond_uncond = True\r\n        elif num_repeats > 0:\r\n            uncond = pad_cond(uncond, num_repeats, empty)\r\n            self.padded_cond_uncond = True\r\n\r\n        return cond, uncond\r\n\r\n    def pad_cond_uncond_v0(self, cond, uncond):\r\n        \"\"\"\r\n        Pads the 'uncond' tensor to match the shape of the 'cond' tensor.\r\n\r\n        If 'uncond' is a dictionary, it is assumed that the 'crossattn' key holds the tensor to be padded.\r\n        If 'uncond' is a tensor, it is padded directly.\r\n\r\n        If the number of columns in 'uncond' is less than the number of columns in 'cond', the last column of 'uncond'\r\n        is repeated to match the number of columns in 'cond'.\r\n\r\n        If the number of columns in 'uncond' is greater than the number of columns in 'cond', 'uncond' is truncated\r\n        to match the number of columns in 'cond'.\r\n\r\n        Args:\r\n            cond (torch.Tensor or DictWithShape): The condition tensor to match the shape of 'uncond'.\r\n            uncond (torch.Tensor or DictWithShape): The tensor to be padded, or a dictionary containing the tensor to be padded.\r\n\r\n        Returns:\r\n            tuple: A tuple containing the 'cond' tensor and the padded 'uncond' tensor.\r\n\r\n        Note:\r\n            This is the padding that was always used in DDIM before version 1.6.0\r\n        \"\"\"\r\n\r\n        is_dict_cond = 
isinstance(uncond, dict)\r\n        uncond_vec = uncond['crossattn'] if is_dict_cond else uncond\r\n\r\n        if uncond_vec.shape[1] < cond.shape[1]:\r\n            last_vector = uncond_vec[:, -1:]\r\n            last_vector_repeated = last_vector.repeat([1, cond.shape[1] - uncond_vec.shape[1], 1])\r\n            uncond_vec = torch.hstack([uncond_vec, last_vector_repeated])\r\n            self.padded_cond_uncond_v0 = True\r\n        elif uncond_vec.shape[1] > cond.shape[1]:\r\n            uncond_vec = uncond_vec[:, :cond.shape[1]]\r\n            self.padded_cond_uncond_v0 = True\r\n\r\n        if is_dict_cond:\r\n            uncond['crossattn'] = uncond_vec\r\n        else:\r\n            uncond = uncond_vec\r\n\r\n        return cond, uncond\r\n\r\n    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):\r\n        if state.interrupted or state.skipped:\r\n            raise sd_samplers_common.InterruptedException\r\n\r\n        if sd_samplers_common.apply_refiner(self, sigma):\r\n            cond = self.sampler.sampler_extra_args['cond']\r\n            uncond = self.sampler.sampler_extra_args['uncond']\r\n\r\n        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,\r\n        # so is_edit_model is set to False to support AND composition.\r\n        is_edit_model = shared.sd_model.cond_stage_key == \"edit\" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0\r\n\r\n        conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)\r\n        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)\r\n\r\n        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), \"AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)\"\r\n\r\n        # If we use masks, blending between the denoised and original latent images occurs here.\r\n        def apply_blend(current_latent):\r\n            
blended_latent = current_latent * self.nmask + self.init_latent * self.mask\r\n\r\n            if self.p.scripts is not None:\r\n                from modules import scripts\r\n                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)\r\n                self.p.scripts.on_mask_blend(self.p, mba)\r\n                blended_latent = mba.blended_latent\r\n\r\n            return blended_latent\r\n\r\n        # Blend in the original latents (before)\r\n        if self.mask_before_denoising and self.mask is not None:\r\n            x = apply_blend(x)\r\n\r\n        batch_size = len(conds_list)\r\n        repeats = [len(conds_list[i]) for i in range(batch_size)]\r\n\r\n        if shared.sd_model.model.conditioning_key == \"crossattn-adm\":\r\n            image_uncond = torch.zeros_like(image_cond)\r\n            make_condition_dict = lambda c_crossattn, c_adm: {\"c_crossattn\": [c_crossattn], \"c_adm\": c_adm}\r\n        else:\r\n            image_uncond = image_cond\r\n            if isinstance(uncond, dict):\r\n                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, \"c_concat\": [c_concat]}\r\n            else:\r\n                make_condition_dict = lambda c_crossattn, c_concat: {\"c_crossattn\": [c_crossattn], \"c_concat\": [c_concat]}\r\n\r\n        if not is_edit_model:\r\n            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])\r\n            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])\r\n            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])\r\n        else:\r\n            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])\r\n            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in 
enumerate(repeats)] + [sigma] + [sigma])\r\n            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])\r\n\r\n        denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self)\r\n        cfg_denoiser_callback(denoiser_params)\r\n        x_in = denoiser_params.x\r\n        image_cond_in = denoiser_params.image_cond\r\n        sigma_in = denoiser_params.sigma\r\n        tensor = denoiser_params.text_cond\r\n        uncond = denoiser_params.text_uncond\r\n        skip_uncond = False\r\n\r\n        if shared.opts.skip_early_cond != 0. and self.step / self.total_steps <= shared.opts.skip_early_cond:\r\n            skip_uncond = True\r\n            self.p.extra_generation_params[\"Skip Early CFG\"] = shared.opts.skip_early_cond\r\n        elif (self.step % 2 or shared.opts.s_min_uncond_all) and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:\r\n            skip_uncond = True\r\n            self.p.extra_generation_params[\"NGMS\"] = s_min_uncond\r\n            if shared.opts.s_min_uncond_all:\r\n                self.p.extra_generation_params[\"NGMS all steps\"] = shared.opts.s_min_uncond_all\r\n\r\n        if skip_uncond:\r\n            x_in = x_in[:-batch_size]\r\n            sigma_in = sigma_in[:-batch_size]\r\n\r\n        self.padded_cond_uncond = False\r\n        self.padded_cond_uncond_v0 = False\r\n        if shared.opts.pad_cond_uncond_v0 and tensor.shape[1] != uncond.shape[1]:\r\n            tensor, uncond = self.pad_cond_uncond_v0(tensor, uncond)\r\n        elif shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]:\r\n            tensor, uncond = self.pad_cond_uncond(tensor, uncond)\r\n\r\n        if tensor.shape[1] == uncond.shape[1] or skip_uncond:\r\n            if is_edit_model:\r\n                cond_in = catenate_conds([tensor, uncond, 
uncond])\r\n            elif skip_uncond:\r\n                cond_in = tensor\r\n            else:\r\n                cond_in = catenate_conds([tensor, uncond])\r\n\r\n            if shared.opts.batch_cond_uncond:\r\n                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))\r\n            else:\r\n                x_out = torch.zeros_like(x_in)\r\n                for batch_offset in range(0, x_out.shape[0], batch_size):\r\n                    a = batch_offset\r\n                    b = a + batch_size\r\n                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))\r\n        else:\r\n            x_out = torch.zeros_like(x_in)\r\n            batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size\r\n            for batch_offset in range(0, tensor.shape[0], batch_size):\r\n                a = batch_offset\r\n                b = min(a + batch_size, tensor.shape[0])\r\n\r\n                if not is_edit_model:\r\n                    c_crossattn = subscript_cond(tensor, a, b)\r\n                else:\r\n                    c_crossattn = torch.cat([tensor[a:b]], uncond)\r\n\r\n                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))\r\n\r\n            if not skip_uncond:\r\n                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))\r\n\r\n        denoised_image_indexes = [x[0][0] for x in conds_list]\r\n        if skip_uncond:\r\n            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])\r\n            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be\r\n\r\n        denoised_params = 
CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)\r\n        cfg_denoised_callback(denoised_params)\r\n\r\n        if self.need_last_noise_uncond:\r\n            self.last_noise_uncond = torch.clone(x_out[-uncond.shape[0]:])\r\n\r\n        if is_edit_model:\r\n            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale * self.cond_scale_miltiplier)\r\n        elif skip_uncond:\r\n            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)\r\n        else:\r\n            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale * self.cond_scale_miltiplier)\r\n\r\n        # Blend in the original latents (after)\r\n        if not self.mask_before_denoising and self.mask is not None:\r\n            denoised = apply_blend(denoised)\r\n\r\n        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)\r\n\r\n        if opts.live_preview_content == \"Prompt\":\r\n            preview = self.sampler.last_latent\r\n        elif opts.live_preview_content == \"Negative prompt\":\r\n            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)\r\n        else:\r\n            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)\r\n\r\n        sd_samplers_common.store_latent(preview)\r\n\r\n        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)\r\n        cfg_after_cfg_callback(after_cfg_callback_params)\r\n        denoised = after_cfg_callback_params.x\r\n\r\n        self.step += 1\r\n        return denoised\r\n\r\n"
  },
  {
    "path": "modules/sd_samplers_common.py",
    "content": "import inspect\r\nfrom collections import namedtuple\r\nimport numpy as np\r\nimport torch\r\nfrom PIL import Image\r\nfrom modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models\r\nfrom modules.shared import opts, state\r\nimport k_diffusion.sampling\r\n\r\n\r\nSamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])\r\n\r\n\r\nclass SamplerData(SamplerDataTuple):\r\n    def total_steps(self, steps):\r\n        if self.options.get(\"second_order\", False):\r\n            steps = steps * 2\r\n\r\n        return steps\r\n\r\n\r\ndef setup_img2img_steps(p, steps=None):\r\n    if opts.img2img_fix_steps or steps is not None:\r\n        requested_steps = (steps or p.steps)\r\n        steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0\r\n        t_enc = requested_steps - 1\r\n    else:\r\n        steps = p.steps\r\n        t_enc = int(min(p.denoising_strength, 0.999) * steps)\r\n\r\n    return steps, t_enc\r\n\r\n\r\napproximation_indexes = {\"Full\": 0, \"Approx NN\": 1, \"Approx cheap\": 2, \"TAESD\": 3}\r\n\r\n\r\ndef samples_to_images_tensor(sample, approximation=None, model=None):\r\n    \"\"\"Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1].\"\"\"\r\n\r\n    if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):\r\n        approximation = approximation_indexes.get(opts.show_progress_type, 0)\r\n\r\n        from modules import lowvram\r\n        if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:\r\n            approximation = 1\r\n\r\n    if approximation == 2:\r\n        x_sample = sd_vae_approx.cheap_approximation(sample)\r\n    elif approximation == 1:\r\n        x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach()\r\n    elif approximation == 
3:\r\n        x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()\r\n        x_sample = x_sample * 2 - 1\r\n    else:\r\n        if model is None:\r\n            model = shared.sd_model\r\n        with torch.no_grad(), devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32\r\n            x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))\r\n\r\n    return x_sample\r\n\r\n\r\ndef single_sample_to_image(sample, approximation=None):\r\n    x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5\r\n\r\n    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)\r\n    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)\r\n    x_sample = x_sample.astype(np.uint8)\r\n\r\n    return Image.fromarray(x_sample)\r\n\r\n\r\ndef decode_first_stage(model, x):\r\n    x = x.to(devices.dtype_vae)\r\n    approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)\r\n    return samples_to_images_tensor(x, approx_index, model)\r\n\r\n\r\ndef sample_to_image(samples, index=0, approximation=None):\r\n    return single_sample_to_image(samples[index], approximation)\r\n\r\n\r\ndef samples_to_image_grid(samples, approximation=None):\r\n    return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])\r\n\r\n\r\ndef images_tensor_to_samples(image, approximation=None, model=None):\r\n    '''image[0, 1] -> latent'''\r\n    if approximation is None:\r\n        approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0)\r\n\r\n    if approximation == 3:\r\n        image = image.to(devices.device, devices.dtype)\r\n        x_latent = sd_vae_taesd.encoder_model()(image)\r\n    else:\r\n        if model is None:\r\n            model = shared.sd_model\r\n        model.first_stage_model.to(devices.dtype_vae)\r\n\r\n        image = image.to(shared.device, dtype=devices.dtype_vae)\r\n        image = image * 2 - 
1\r\n        if len(image) > 1:\r\n            x_latent = torch.stack([\r\n                model.get_first_stage_encoding(\r\n                    model.encode_first_stage(torch.unsqueeze(img, 0))\r\n                )[0]\r\n                for img in image\r\n            ])\r\n        else:\r\n            x_latent = model.get_first_stage_encoding(model.encode_first_stage(image))\r\n\r\n    return x_latent\r\n\r\n\r\ndef store_latent(decoded):\r\n    state.current_latent = decoded\r\n\r\n    if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:\r\n        if not shared.parallel_processing_allowed:\r\n            shared.state.assign_current_image(sample_to_image(decoded))\r\n\r\n\r\ndef is_sampler_using_eta_noise_seed_delta(p):\r\n    \"\"\"returns whether sampler from config will use eta noise seed delta for image creation\"\"\"\r\n\r\n    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)\r\n\r\n    eta = p.eta\r\n\r\n    if eta is None and p.sampler is not None:\r\n        eta = p.sampler.eta\r\n\r\n    if eta is None and sampler_config is not None:\r\n        eta = 0 if sampler_config.options.get(\"default_eta_is_0\", False) else 1.0\r\n\r\n    if eta == 0:\r\n        return False\r\n\r\n    return sampler_config.options.get(\"uses_ensd\", False)\r\n\r\n\r\nclass InterruptedException(BaseException):\r\n    pass\r\n\r\n\r\ndef replace_torchsde_browinan():\r\n    import torchsde._brownian.brownian_interval\r\n\r\n    def torchsde_randn(size, dtype, device, seed):\r\n        return devices.randn_local(seed, size).to(device=device, dtype=dtype)\r\n\r\n    torchsde._brownian.brownian_interval._randn = torchsde_randn\r\n\r\n\r\nreplace_torchsde_browinan()\r\n\r\n\r\ndef apply_refiner(cfg_denoiser, sigma=None):\r\n    if opts.refiner_switch_by_sample_steps or sigma is None:\r\n        completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps\r\n        
cfg_denoiser.p.extra_generation_params[\"Refiner switch by sampling steps\"] = True\r\n\r\n    else:\r\n        # torch.max(sigma) only to handle rare case where we might have different sigmas in the same batch\r\n        try:\r\n            timestep = torch.argmin(torch.abs(cfg_denoiser.inner_model.sigmas.to(sigma.device) - torch.max(sigma)))\r\n        except AttributeError:  # for samplers that don't use sigmas (DDIM) sigma is actually the timestep\r\n            timestep = torch.max(sigma).to(dtype=int)\r\n        completed_ratio = (999 - timestep) / 1000\r\n\r\n    refiner_switch_at = cfg_denoiser.p.refiner_switch_at\r\n    refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info\r\n\r\n    if refiner_switch_at is not None and completed_ratio < refiner_switch_at:\r\n        return False\r\n\r\n    if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info:\r\n        return False\r\n\r\n    if getattr(cfg_denoiser.p, \"enable_hr\", False):\r\n        is_second_pass = cfg_denoiser.p.is_hr_pass\r\n\r\n        if opts.hires_fix_refiner_pass == \"first pass\" and is_second_pass:\r\n            return False\r\n\r\n        if opts.hires_fix_refiner_pass == \"second pass\" and not is_second_pass:\r\n            return False\r\n\r\n        if opts.hires_fix_refiner_pass != \"second pass\":\r\n            cfg_denoiser.p.extra_generation_params['Hires refiner'] = opts.hires_fix_refiner_pass\r\n\r\n    cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title\r\n    cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at\r\n\r\n    with sd_models.SkipWritingToConfig():\r\n        sd_models.reload_model_weights(info=refiner_checkpoint_info)\r\n\r\n    devices.torch_gc()\r\n    cfg_denoiser.p.setup_conds()\r\n    cfg_denoiser.update_inner_model()\r\n\r\n    return True\r\n\r\n\r\nclass TorchHijack:\r\n    \"\"\"This is here to replace torch.randn_like of 
k-diffusion.\r\n\r\n    k-diffusion has random_sampler argument for most samplers, but not for all, so\r\n    this is needed to properly replace every use of torch.randn_like.\r\n\r\n    We need to replace to make images generated in batches to be same as images generated individually.\"\"\"\r\n\r\n    def __init__(self, p):\r\n        self.rng = p.rng\r\n\r\n    def __getattr__(self, item):\r\n        if item == 'randn_like':\r\n            return self.randn_like\r\n\r\n        if hasattr(torch, item):\r\n            return getattr(torch, item)\r\n\r\n        raise AttributeError(f\"'{type(self).__name__}' object has no attribute '{item}'\")\r\n\r\n    def randn_like(self, x):\r\n        return self.rng.next()\r\n\r\n\r\nclass Sampler:\r\n    def __init__(self, funcname):\r\n        self.funcname = funcname\r\n        self.func = funcname\r\n        self.extra_params = []\r\n        self.sampler_noises = None\r\n        self.stop_at = None\r\n        self.eta = None\r\n        self.config: SamplerData = None  # set by the function calling the constructor\r\n        self.last_latent = None\r\n        self.s_min_uncond = None\r\n        self.s_churn = 0.0\r\n        self.s_tmin = 0.0\r\n        self.s_tmax = float('inf')\r\n        self.s_noise = 1.0\r\n\r\n        self.eta_option_field = 'eta_ancestral'\r\n        self.eta_infotext_field = 'Eta'\r\n        self.eta_default = 1.0\r\n\r\n        self.conditioning_key = getattr(shared.sd_model.model, 'conditioning_key', 'crossattn')\r\n\r\n        self.p = None\r\n        self.model_wrap_cfg = None\r\n        self.sampler_extra_args = None\r\n        self.options = {}\r\n\r\n    def callback_state(self, d):\r\n        step = d['i']\r\n\r\n        if self.stop_at is not None and step > self.stop_at:\r\n            raise InterruptedException\r\n\r\n        state.sampling_step = step\r\n        shared.total_tqdm.update()\r\n\r\n    def launch_sampling(self, steps, func):\r\n        self.model_wrap_cfg.steps = steps\r\n   
     self.model_wrap_cfg.total_steps = self.config.total_steps(steps)\r\n        state.sampling_steps = steps\r\n        state.sampling_step = 0\r\n\r\n        try:\r\n            return func()\r\n        except RecursionError:\r\n            print(\r\n                'Encountered RecursionError during sampling, returning last latent. '\r\n                'rho >5 with a polyexponential scheduler may cause this error. '\r\n                'You should try to use a smaller rho value instead.'\r\n            )\r\n            return self.last_latent\r\n        except InterruptedException:\r\n            return self.last_latent\r\n\r\n    def number_of_needed_noises(self, p):\r\n        return p.steps\r\n\r\n    def initialize(self, p) -> dict:\r\n        self.p = p\r\n        self.model_wrap_cfg.p = p\r\n        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None\r\n        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None\r\n        self.model_wrap_cfg.step = 0\r\n        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)\r\n        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)\r\n        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)\r\n\r\n        k_diffusion.sampling.torch = TorchHijack(p)\r\n\r\n        extra_params_kwargs = {}\r\n        for param_name in self.extra_params:\r\n            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:\r\n                extra_params_kwargs[param_name] = getattr(p, param_name)\r\n\r\n        if 'eta' in inspect.signature(self.func).parameters:\r\n            if self.eta != self.eta_default:\r\n                p.extra_generation_params[self.eta_infotext_field] = self.eta\r\n\r\n            extra_params_kwargs['eta'] = self.eta\r\n\r\n        if len(self.extra_params) > 0:\r\n            s_churn = getattr(opts, 's_churn', p.s_churn)\r\n            s_tmin = getattr(opts, 's_tmin', p.s_tmin)\r\n      
      s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf\r\n            s_noise = getattr(opts, 's_noise', p.s_noise)\r\n\r\n            if 's_churn' in extra_params_kwargs and s_churn != self.s_churn:\r\n                extra_params_kwargs['s_churn'] = s_churn\r\n                p.s_churn = s_churn\r\n                p.extra_generation_params['Sigma churn'] = s_churn\r\n            if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin:\r\n                extra_params_kwargs['s_tmin'] = s_tmin\r\n                p.s_tmin = s_tmin\r\n                p.extra_generation_params['Sigma tmin'] = s_tmin\r\n            if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax:\r\n                extra_params_kwargs['s_tmax'] = s_tmax\r\n                p.s_tmax = s_tmax\r\n                p.extra_generation_params['Sigma tmax'] = s_tmax\r\n            if 's_noise' in extra_params_kwargs and s_noise != self.s_noise:\r\n                extra_params_kwargs['s_noise'] = s_noise\r\n                p.s_noise = s_noise\r\n                p.extra_generation_params['Sigma noise'] = s_noise\r\n\r\n        return extra_params_kwargs\r\n\r\n    def create_noise_sampler(self, x, sigmas, p):\r\n        \"\"\"For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes\"\"\"\r\n        if shared.opts.no_dpmpp_sde_batch_determinism:\r\n            return None\r\n\r\n        from k_diffusion.sampling import BrownianTreeNoiseSampler\r\n        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()\r\n        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]\r\n        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)\r\n\r\n    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):\r\n        raise NotImplementedError()\r\n\r\n    def sample_img2img(self, p, x, noise, conditioning, 
unconditional_conditioning, steps=None, image_conditioning=None):\r\n        raise NotImplementedError()\r\n\r\n    def add_infotext(self, p):\r\n        if self.model_wrap_cfg.padded_cond_uncond:\r\n            p.extra_generation_params[\"Pad conds\"] = True\r\n\r\n        if self.model_wrap_cfg.padded_cond_uncond_v0:\r\n            p.extra_generation_params[\"Pad conds v0\"] = True\r\n"
  },
  {
    "path": "modules/sd_samplers_compvis.py",
    "content": ""
  },
  {
    "path": "modules/sd_samplers_extra.py",
    "content": "import torch\r\nimport tqdm\r\nimport k_diffusion.sampling\r\n\r\n\r\n@torch.no_grad()\r\ndef restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):\r\n    \"\"\"Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)\r\n    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}\r\n    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list\r\n    \"\"\"\r\n    extra_args = {} if extra_args is None else extra_args\r\n    s_in = x.new_ones([x.shape[0]])\r\n    step_id = 0\r\n    from k_diffusion.sampling import to_d, get_sigmas_karras\r\n\r\n    def heun_step(x, old_sigma, new_sigma, second_order=True):\r\n        nonlocal step_id\r\n        denoised = model(x, old_sigma * s_in, **extra_args)\r\n        d = to_d(x, old_sigma, denoised)\r\n        if callback is not None:\r\n            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})\r\n        dt = new_sigma - old_sigma\r\n        if new_sigma == 0 or not second_order:\r\n            # Euler method\r\n            x = x + d * dt\r\n        else:\r\n            # Heun's method\r\n            x_2 = x + d * dt\r\n            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)\r\n            d_2 = to_d(x_2, new_sigma, denoised_2)\r\n            d_prime = (d + d_2) / 2\r\n            x = x + d_prime * dt\r\n        step_id += 1\r\n        return x\r\n\r\n    steps = sigmas.shape[0] - 1\r\n    if restart_list is None:\r\n        if steps >= 20:\r\n            restart_steps = 9\r\n            restart_times = 1\r\n            if steps >= 36:\r\n                restart_steps = steps // 4\r\n                restart_times = 2\r\n            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)\r\n            restart_list = 
{0.1: [restart_steps + 1, restart_times, 2]}\r\n        else:\r\n            restart_list = {}\r\n\r\n    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}\r\n\r\n    step_list = []\r\n    for i in range(len(sigmas) - 1):\r\n        step_list.append((sigmas[i], sigmas[i + 1]))\r\n        if i + 1 in restart_list:\r\n            restart_steps, restart_times, restart_max = restart_list[i + 1]\r\n            min_idx = i + 1\r\n            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))\r\n            if max_idx < min_idx:\r\n                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]\r\n                while restart_times > 0:\r\n                    restart_times -= 1\r\n                    step_list.extend(zip(sigma_restart[:-1], sigma_restart[1:]))\r\n\r\n    last_sigma = None\r\n    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):\r\n        if last_sigma is None:\r\n            last_sigma = old_sigma\r\n        elif last_sigma < old_sigma:\r\n            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5\r\n        x = heun_step(x, old_sigma, new_sigma)\r\n        last_sigma = new_sigma\r\n\r\n    return x\r\n"
  },
  {
    "path": "modules/sd_samplers_kdiffusion.py",
    "content": "import torch\r\nimport inspect\r\nimport k_diffusion.sampling\r\nfrom modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser, sd_schedulers, devices\r\nfrom modules.sd_samplers_cfg_denoiser import CFGDenoiser  # noqa: F401\r\nfrom modules.script_callbacks import ExtraNoiseParams, extra_noise_callback\r\n\r\nfrom modules.shared import opts\r\nimport modules.shared as shared\r\n\r\nsamplers_k_diffusion = [\r\n    ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {'scheduler': 'karras'}),\r\n    ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {'scheduler': 'karras', \"second_order\": True, \"brownian_noise\": True}),\r\n    ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde'], {'scheduler': 'exponential', \"brownian_noise\": True}),\r\n    ('DPM++ 2M SDE Heun', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_heun'], {'scheduler': 'exponential', \"brownian_noise\": True, \"solver_type\": \"heun\"}),\r\n    ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {'scheduler': 'karras', \"uses_ensd\": True, \"second_order\": True}),\r\n    ('DPM++ 3M SDE', 'sample_dpmpp_3m_sde', ['k_dpmpp_3m_sde'], {'scheduler': 'exponential', 'discard_next_to_last_sigma': True, \"brownian_noise\": True}),\r\n    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {\"uses_ensd\": True}),\r\n    ('Euler', 'sample_euler', ['k_euler'], {}),\r\n    ('LMS', 'sample_lms', ['k_lms'], {}),\r\n    ('Heun', 'sample_heun', ['k_heun'], {\"second_order\": True}),\r\n    ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, \"second_order\": True}),\r\n    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, \"uses_ensd\": True, \"second_order\": True}),\r\n    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {\"uses_ensd\": True}),\r\n    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {\"uses_ensd\": True}),\r\n    ('Restart', 
sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras', \"second_order\": True}),\r\n]\r\n\r\n\r\nsamplers_data_k_diffusion = [\r\n    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)\r\n    for label, funcname, aliases, options in samplers_k_diffusion\r\n    if callable(funcname) or hasattr(k_diffusion.sampling, funcname)\r\n]\r\n\r\nsampler_extra_params = {\r\n    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],\r\n    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],\r\n    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],\r\n    'sample_dpm_fast': ['s_noise'],\r\n    'sample_dpm_2_ancestral': ['s_noise'],\r\n    'sample_dpmpp_2s_ancestral': ['s_noise'],\r\n    'sample_dpmpp_sde': ['s_noise'],\r\n    'sample_dpmpp_2m_sde': ['s_noise'],\r\n    'sample_dpmpp_3m_sde': ['s_noise'],\r\n}\r\n\r\nk_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion}\r\nk_diffusion_scheduler = {x.name: x.function for x in sd_schedulers.schedulers}\r\n\r\n\r\nclass CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser):\r\n    @property\r\n    def inner_model(self):\r\n        if self.model_wrap is None:\r\n            denoiser_constructor = getattr(shared.sd_model, 'create_denoiser', None)\r\n\r\n            if denoiser_constructor is not None:\r\n                self.model_wrap = denoiser_constructor()\r\n            else:\r\n                denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == \"v\" else k_diffusion.external.CompVisDenoiser\r\n                self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization)\r\n\r\n        return self.model_wrap\r\n\r\n\r\nclass KDiffusionSampler(sd_samplers_common.Sampler):\r\n    def __init__(self, funcname, sd_model, options=None):\r\n        super().__init__(funcname)\r\n\r\n        self.extra_params = 
sampler_extra_params.get(funcname, [])\r\n\r\n        self.options = options or {}\r\n        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)\r\n\r\n        self.model_wrap_cfg = CFGDenoiserKDiffusion(self)\r\n        self.model_wrap = self.model_wrap_cfg.inner_model\r\n\r\n    def get_sigmas(self, p, steps):\r\n        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)\r\n        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:\r\n            discard_next_to_last_sigma = True\r\n            p.extra_generation_params[\"Discard penultimate sigma\"] = True\r\n\r\n        steps += 1 if discard_next_to_last_sigma else 0\r\n\r\n        scheduler_name = (p.hr_scheduler if p.is_hr_pass else p.scheduler) or 'Automatic'\r\n        if scheduler_name == 'Automatic':\r\n            scheduler_name = self.config.options.get('scheduler', None)\r\n\r\n        scheduler = sd_schedulers.schedulers_map.get(scheduler_name)\r\n\r\n        m_sigma_min, m_sigma_max = self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()\r\n        sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max)\r\n\r\n        if p.sampler_noise_scheduler_override:\r\n            sigmas = p.sampler_noise_scheduler_override(steps)\r\n        elif scheduler is None or scheduler.function is None:\r\n            sigmas = self.model_wrap.get_sigmas(steps)\r\n        else:\r\n            sigmas_kwargs = {'sigma_min': sigma_min, 'sigma_max': sigma_max}\r\n\r\n            if scheduler.label != 'Automatic' and not p.is_hr_pass:\r\n                p.extra_generation_params[\"Schedule type\"] = scheduler.label\r\n            elif scheduler.label != p.extra_generation_params.get(\"Schedule type\"):\r\n                p.extra_generation_params[\"Hires schedule type\"] = scheduler.label\r\n\r\n            if opts.sigma_min != 0 
and opts.sigma_min != m_sigma_min:\r\n                sigmas_kwargs['sigma_min'] = opts.sigma_min\r\n                p.extra_generation_params[\"Schedule min sigma\"] = opts.sigma_min\r\n\r\n            if opts.sigma_max != 0 and opts.sigma_max != m_sigma_max:\r\n                sigmas_kwargs['sigma_max'] = opts.sigma_max\r\n                p.extra_generation_params[\"Schedule max sigma\"] = opts.sigma_max\r\n\r\n            if scheduler.default_rho != -1 and opts.rho != 0 and opts.rho != scheduler.default_rho:\r\n                sigmas_kwargs['rho'] = opts.rho\r\n                p.extra_generation_params[\"Schedule rho\"] = opts.rho\r\n\r\n            if scheduler.need_inner_model:\r\n                sigmas_kwargs['inner_model'] = self.model_wrap\r\n\r\n            if scheduler.label == 'Beta':\r\n                p.extra_generation_params[\"Beta schedule alpha\"] = opts.beta_dist_alpha\r\n                p.extra_generation_params[\"Beta schedule beta\"] = opts.beta_dist_beta\r\n\r\n            sigmas = scheduler.function(n=steps, **sigmas_kwargs, device=devices.cpu)\r\n\r\n        if discard_next_to_last_sigma:\r\n            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])\r\n\r\n        return sigmas.cpu()\r\n\r\n    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):\r\n        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)\r\n\r\n        sigmas = self.get_sigmas(p, steps)\r\n        sigma_sched = sigmas[steps - t_enc - 1:]\r\n\r\n        if hasattr(shared.sd_model, 'add_noise_to_latent'):\r\n            xi = shared.sd_model.add_noise_to_latent(x, noise, sigma_sched[0])\r\n        else:\r\n            xi = x + noise * sigma_sched[0]\r\n\r\n        if opts.img2img_extra_noise > 0:\r\n            p.extra_generation_params[\"Extra noise\"] = opts.img2img_extra_noise\r\n            extra_noise_params = ExtraNoiseParams(noise, x, xi)\r\n            
extra_noise_callback(extra_noise_params)\r\n            noise = extra_noise_params.noise\r\n            xi += noise * opts.img2img_extra_noise\r\n\r\n        extra_params_kwargs = self.initialize(p)\r\n        parameters = inspect.signature(self.func).parameters\r\n\r\n        if 'sigma_min' in parameters:\r\n            ## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last\r\n            extra_params_kwargs['sigma_min'] = sigma_sched[-2]\r\n        if 'sigma_max' in parameters:\r\n            extra_params_kwargs['sigma_max'] = sigma_sched[0]\r\n        if 'n' in parameters:\r\n            extra_params_kwargs['n'] = len(sigma_sched) - 1\r\n        if 'sigma_sched' in parameters:\r\n            extra_params_kwargs['sigma_sched'] = sigma_sched\r\n        if 'sigmas' in parameters:\r\n            extra_params_kwargs['sigmas'] = sigma_sched\r\n\r\n        if self.config.options.get('brownian_noise', False):\r\n            noise_sampler = self.create_noise_sampler(x, sigmas, p)\r\n            extra_params_kwargs['noise_sampler'] = noise_sampler\r\n\r\n        if self.config.options.get('solver_type', None) == 'heun':\r\n            extra_params_kwargs['solver_type'] = 'heun'\r\n\r\n        self.model_wrap_cfg.init_latent = x\r\n        self.last_latent = x\r\n        self.sampler_extra_args = {\r\n            'cond': conditioning,\r\n            'image_cond': image_conditioning,\r\n            'uncond': unconditional_conditioning,\r\n            'cond_scale': p.cfg_scale,\r\n            's_min_uncond': self.s_min_uncond\r\n        }\r\n\r\n        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))\r\n\r\n        self.add_infotext(p)\r\n\r\n        return samples\r\n\r\n    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):\r\n        steps = 
steps or p.steps\r\n\r\n        sigmas = self.get_sigmas(p, steps)\r\n\r\n        if opts.sgm_noise_multiplier:\r\n            p.extra_generation_params[\"SGM noise multiplier\"] = True\r\n            x = x * torch.sqrt(1.0 + sigmas[0] ** 2.0)\r\n        else:\r\n            x = x * sigmas[0]\r\n\r\n        extra_params_kwargs = self.initialize(p)\r\n        parameters = inspect.signature(self.func).parameters\r\n\r\n        if 'n' in parameters:\r\n            extra_params_kwargs['n'] = steps\r\n\r\n        if 'sigma_min' in parameters:\r\n            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()\r\n            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()\r\n\r\n        if 'sigmas' in parameters:\r\n            extra_params_kwargs['sigmas'] = sigmas\r\n\r\n        if self.config.options.get('brownian_noise', False):\r\n            noise_sampler = self.create_noise_sampler(x, sigmas, p)\r\n            extra_params_kwargs['noise_sampler'] = noise_sampler\r\n\r\n        if self.config.options.get('solver_type', None) == 'heun':\r\n            extra_params_kwargs['solver_type'] = 'heun'\r\n\r\n        self.last_latent = x\r\n        self.sampler_extra_args = {\r\n            'cond': conditioning,\r\n            'image_cond': image_conditioning,\r\n            'uncond': unconditional_conditioning,\r\n            'cond_scale': p.cfg_scale,\r\n            's_min_uncond': self.s_min_uncond\r\n        }\r\n\r\n        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))\r\n\r\n        self.add_infotext(p)\r\n\r\n        return samples\r\n\r\n\r\n"
  },
  {
    "path": "modules/sd_samplers_lcm.py",
    "content": "import torch\n\nfrom k_diffusion import utils, sampling\nfrom k_diffusion.external import DiscreteEpsDDPMDenoiser\nfrom k_diffusion.sampling import default_noise_sampler, trange\n\nfrom modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion, sd_samplers_common\n\n\nclass LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):\n    def __init__(self, model):\n        timesteps = 1000\n        original_timesteps = 50     # LCM Original Timesteps (default=50, for current version of LCM)\n        self.skip_steps = timesteps // original_timesteps\n\n        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)\n        for x in range(original_timesteps):\n            alphas_cumprod_valid[original_timesteps - 1 - x] = model.alphas_cumprod[timesteps - 1 - x * self.skip_steps]\n\n        super().__init__(model, alphas_cumprod_valid, quantize=None)\n\n\n    def get_sigmas(self, n=None,):\n        if n is None:\n            return sampling.append_zero(self.sigmas.flip(0))\n\n        start = self.sigma_to_t(self.sigma_max)\n        end = self.sigma_to_t(self.sigma_min)\n\n        t = torch.linspace(start, end, n, device=shared.sd_model.device)\n\n        return sampling.append_zero(self.t_to_sigma(t))\n\n\n    def sigma_to_t(self, sigma, quantize=None):\n        log_sigma = sigma.log()\n        dists = log_sigma - self.log_sigmas[:, None]\n        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)\n\n\n    def t_to_sigma(self, timestep):\n        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))\n        return super().t_to_sigma(t)\n\n\n    def get_eps(self, *args, **kwargs):\n        return self.inner_model.apply_model(*args, **kwargs)\n\n\n    def get_scaled_out(self, sigma, output, input):\n        sigma_data = 0.5\n        scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 
10.0\n\n        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)\n        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5\n\n        return c_out * output + c_skip * input\n\n\n    def forward(self, input, sigma, **kwargs):\n        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]\n        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)\n        return self.get_scaled_out(sigma, input + eps * c_out, input)\n\n\ndef sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):\n    extra_args = {} if extra_args is None else extra_args\n    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler\n    s_in = x.new_ones([x.shape[0]])\n\n    for i in trange(len(sigmas) - 1, disable=disable):\n        denoised = model(x, sigmas[i] * s_in, **extra_args)\n\n        if callback is not None:\n            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n\n        x = denoised\n        if sigmas[i + 1] > 0:\n            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])\n    return x\n\n\nclass CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):\n    @property\n    def inner_model(self):\n        if self.model_wrap is None:\n            denoiser = LCMCompVisDenoiser\n            self.model_wrap = denoiser(shared.sd_model)\n\n        return self.model_wrap\n\n\nclass LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):\n    def __init__(self, funcname, sd_model, options=None):\n        super().__init__(funcname, sd_model, options)\n        self.model_wrap_cfg = CFGDenoiserLCM(self)\n        self.model_wrap = self.model_wrap_cfg.inner_model\n\n\nsamplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]\nsamplers_data_lcm = [\n    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)\n    for label, funcname, aliases, 
options in samplers_lcm\n]\n"
  },
  {
    "path": "modules/sd_samplers_timesteps.py",
    "content": "import torch\r\nimport inspect\r\nimport sys\r\nfrom modules import devices, sd_samplers_common, sd_samplers_timesteps_impl\r\nfrom modules.sd_samplers_cfg_denoiser import CFGDenoiser\r\nfrom modules.script_callbacks import ExtraNoiseParams, extra_noise_callback\r\n\r\nfrom modules.shared import opts\r\nimport modules.shared as shared\r\n\r\nsamplers_timesteps = [\r\n    ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),\r\n    ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),\r\n    ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),\r\n    ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),\r\n]\r\n\r\n\r\nsamplers_data_timesteps = [\r\n    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)\r\n    for label, funcname, aliases, options in samplers_timesteps\r\n]\r\n\r\n\r\nclass CompVisTimestepsDenoiser(torch.nn.Module):\r\n    def __init__(self, model, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n        self.inner_model = model\r\n\r\n    def forward(self, input, timesteps, **kwargs):\r\n        return self.inner_model.apply_model(input, timesteps, **kwargs)\r\n\r\n\r\nclass CompVisTimestepsVDenoiser(torch.nn.Module):\r\n    def __init__(self, model, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n        self.inner_model = model\r\n\r\n    def predict_eps_from_z_and_v(self, x_t, t, v):\r\n        return torch.sqrt(self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * v + torch.sqrt(1 - self.inner_model.alphas_cumprod)[t.to(torch.int), None, None, None] * x_t\r\n\r\n    def forward(self, input, timesteps, **kwargs):\r\n        model_output = self.inner_model.apply_model(input, timesteps, **kwargs)\r\n        e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)\r\n        return e_t\r\n\r\n\r\nclass CFGDenoiserTimesteps(CFGDenoiser):\r\n\r\n    def __init__(self, 
sampler):\r\n        super().__init__(sampler)\r\n\r\n        self.alphas = shared.sd_model.alphas_cumprod\r\n        self.mask_before_denoising = True\r\n\r\n    def get_pred_x0(self, x_in, x_out, sigma):\r\n        ts = sigma.to(dtype=int)\r\n\r\n        a_t = self.alphas[ts][:, None, None, None]\r\n        sqrt_one_minus_at = (1 - a_t).sqrt()\r\n\r\n        pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()\r\n\r\n        return pred_x0\r\n\r\n    @property\r\n    def inner_model(self):\r\n        if self.model_wrap is None:\r\n            denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == \"v\" else CompVisTimestepsDenoiser\r\n            self.model_wrap = denoiser(shared.sd_model)\r\n\r\n        return self.model_wrap\r\n\r\n\r\nclass CompVisSampler(sd_samplers_common.Sampler):\r\n    def __init__(self, funcname, sd_model):\r\n        super().__init__(funcname)\r\n\r\n        self.eta_option_field = 'eta_ddim'\r\n        self.eta_infotext_field = 'Eta DDIM'\r\n        self.eta_default = 0.0\r\n\r\n        self.model_wrap_cfg = CFGDenoiserTimesteps(self)\r\n        self.model_wrap = self.model_wrap_cfg.inner_model\r\n\r\n    def get_timesteps(self, p, steps):\r\n        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)\r\n        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:\r\n            discard_next_to_last_sigma = True\r\n            p.extra_generation_params[\"Discard penultimate sigma\"] = True\r\n\r\n        steps += 1 if discard_next_to_last_sigma else 0\r\n\r\n        timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)\r\n\r\n        return timesteps\r\n\r\n    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):\r\n        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)\r\n\r\n        
timesteps = self.get_timesteps(p, steps)\r\n        timesteps_sched = timesteps[:t_enc]\r\n\r\n        alphas_cumprod = shared.sd_model.alphas_cumprod\r\n        sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])\r\n        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])\r\n\r\n        xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod\r\n\r\n        if opts.img2img_extra_noise > 0:\r\n            p.extra_generation_params[\"Extra noise\"] = opts.img2img_extra_noise\r\n            extra_noise_params = ExtraNoiseParams(noise, x, xi)\r\n            extra_noise_callback(extra_noise_params)\r\n            noise = extra_noise_params.noise\r\n            xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod\r\n\r\n        extra_params_kwargs = self.initialize(p)\r\n        parameters = inspect.signature(self.func).parameters\r\n\r\n        if 'timesteps' in parameters:\r\n            extra_params_kwargs['timesteps'] = timesteps_sched\r\n        if 'is_img2img' in parameters:\r\n            extra_params_kwargs['is_img2img'] = True\r\n\r\n        self.model_wrap_cfg.init_latent = x\r\n        self.last_latent = x\r\n        self.sampler_extra_args = {\r\n            'cond': conditioning,\r\n            'image_cond': image_conditioning,\r\n            'uncond': unconditional_conditioning,\r\n            'cond_scale': p.cfg_scale,\r\n            's_min_uncond': self.s_min_uncond\r\n        }\r\n\r\n        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))\r\n\r\n        self.add_infotext(p)\r\n\r\n        return samples\r\n\r\n    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):\r\n        steps = steps or p.steps\r\n        timesteps = self.get_timesteps(p, steps)\r\n\r\n        extra_params_kwargs = 
self.initialize(p)\r\n        parameters = inspect.signature(self.func).parameters\r\n\r\n        if 'timesteps' in parameters:\r\n            extra_params_kwargs['timesteps'] = timesteps\r\n\r\n        self.last_latent = x\r\n        self.sampler_extra_args = {\r\n            'cond': conditioning,\r\n            'image_cond': image_conditioning,\r\n            'uncond': unconditional_conditioning,\r\n            'cond_scale': p.cfg_scale,\r\n            's_min_uncond': self.s_min_uncond\r\n        }\r\n        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))\r\n\r\n        self.add_infotext(p)\r\n\r\n        return samples\r\n\r\n\r\nsys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]\r\nVanillaStableDiffusionSampler = CompVisSampler  # temp. compatibility with older extensions\r\n"
  },
  {
    "path": "modules/sd_samplers_timesteps_impl.py",
    "content": "import torch\r\nimport tqdm\r\nimport k_diffusion.sampling\r\nimport numpy as np\r\n\r\nfrom modules import shared\r\nfrom modules.models.diffusion.uni_pc import uni_pc\r\nfrom modules.torch_utils import float64\r\n\r\n\r\n@torch.no_grad()\r\ndef ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):\r\n    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod\r\n    alphas = alphas_cumprod[timesteps]\r\n    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))\r\n    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)\r\n    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))\r\n\r\n    extra_args = {} if extra_args is None else extra_args\r\n    s_in = x.new_ones((x.shape[0]))\r\n    s_x = x.new_ones((x.shape[0], 1, 1, 1))\r\n    for i in tqdm.trange(len(timesteps) - 1, disable=disable):\r\n        index = len(timesteps) - 1 - i\r\n\r\n        e_t = model(x, timesteps[index].item() * s_in, **extra_args)\r\n\r\n        a_t = alphas[index].item() * s_x\r\n        a_prev = alphas_prev[index].item() * s_x\r\n        sigma_t = sigmas[index].item() * s_x\r\n        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x\r\n\r\n        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\r\n        dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * e_t\r\n        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)\r\n        x = a_prev.sqrt() * pred_x0 + dir_xt + noise\r\n\r\n        if callback is not None:\r\n            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})\r\n\r\n    return x\r\n\r\n\r\n@torch.no_grad()\r\ndef ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):\r\n    \"\"\" Implements CFG++: Manifold-constrained Classifier Free Guidance For Diffusion Models (2024).\r\n    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.\r\n    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].\r\n    \"\"\"\r\n    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod\r\n    alphas = alphas_cumprod[timesteps]\r\n    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))\r\n    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)\r\n    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))\r\n\r\n    model.cond_scale_miltiplier = 1 / 12.5\r\n    model.need_last_noise_uncond = True\r\n\r\n    extra_args = {} if extra_args is None else extra_args\r\n    s_in = x.new_ones((x.shape[0]))\r\n    s_x = x.new_ones((x.shape[0], 1, 1, 1))\r\n    for i in tqdm.trange(len(timesteps) - 1, disable=disable):\r\n        index = len(timesteps) - 1 - i\r\n\r\n        e_t = model(x, timesteps[index].item() * s_in, **extra_args)\r\n        last_noise_uncond = model.last_noise_uncond\r\n\r\n        a_t = alphas[index].item() * s_x\r\n        a_prev = alphas_prev[index].item() * s_x\r\n        sigma_t = sigmas[index].item() * s_x\r\n        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x\r\n\r\n        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\r\n        dir_xt = (1. 
- a_prev - sigma_t ** 2).sqrt() * last_noise_uncond\r\n        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)\r\n        x = a_prev.sqrt() * pred_x0 + dir_xt + noise\r\n\r\n        if callback is not None:\r\n            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})\r\n\r\n    return x\r\n\r\n\r\n@torch.no_grad()\r\ndef plms(model, x, timesteps, extra_args=None, callback=None, disable=None):\r\n    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod\r\n    alphas = alphas_cumprod[timesteps]\r\n    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))\r\n    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)\r\n\r\n    extra_args = {} if extra_args is None else extra_args\r\n    s_in = x.new_ones([x.shape[0]])\r\n    s_x = x.new_ones((x.shape[0], 1, 1, 1))\r\n    old_eps = []\r\n\r\n    def get_x_prev_and_pred_x0(e_t, index):\r\n        # select parameters corresponding to the currently considered timestep\r\n        a_t = alphas[index].item() * s_x\r\n        a_prev = alphas_prev[index].item() * s_x\r\n        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x\r\n\r\n        # current prediction for x_0\r\n        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\r\n\r\n        # direction pointing to x_t\r\n        dir_xt = (1. 
- a_prev).sqrt() * e_t\r\n        x_prev = a_prev.sqrt() * pred_x0 + dir_xt\r\n        return x_prev, pred_x0\r\n\r\n    for i in tqdm.trange(len(timesteps) - 1, disable=disable):\r\n        index = len(timesteps) - 1 - i\r\n        ts = timesteps[index].item() * s_in\r\n        t_next = timesteps[max(index - 1, 0)].item() * s_in\r\n\r\n        e_t = model(x, ts, **extra_args)\r\n\r\n        if len(old_eps) == 0:\r\n            # Pseudo Improved Euler (2nd order)\r\n            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)\r\n            e_t_next = model(x_prev, t_next, **extra_args)\r\n            e_t_prime = (e_t + e_t_next) / 2\r\n        elif len(old_eps) == 1:\r\n            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)\r\n            e_t_prime = (3 * e_t - old_eps[-1]) / 2\r\n        elif len(old_eps) == 2:\r\n            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)\r\n            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12\r\n        else:\r\n            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)\r\n            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24\r\n\r\n        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)\r\n\r\n        old_eps.append(e_t)\r\n        if len(old_eps) >= 4:\r\n            old_eps.pop(0)\r\n\r\n        x = x_prev\r\n\r\n        if callback is not None:\r\n            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})\r\n\r\n    return x\r\n\r\n\r\nclass UniPCCFG(uni_pc.UniPC):\r\n    def __init__(self, cfg_model, extra_args, callback, *args, **kwargs):\r\n        super().__init__(None, *args, **kwargs)\r\n\r\n        def after_update(x, model_x):\r\n            callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x})\r\n            self.index += 1\r\n\r\n        self.cfg_model = cfg_model\r\n        self.extra_args = extra_args\r\n        self.callback = callback\r\n        
self.index = 0\r\n        self.after_update = after_update\r\n\r\n    def get_model_input_time(self, t_continuous):\r\n        return (t_continuous - 1. / self.noise_schedule.total_N) * 1000.\r\n\r\n    def model(self, x, t):\r\n        t_input = self.get_model_input_time(t)\r\n\r\n        res = self.cfg_model(x, t_input, **self.extra_args)\r\n\r\n        return res\r\n\r\n\r\ndef unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False):\r\n    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod\r\n\r\n    ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)\r\n    t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None  # this is likely off by a bit - if someone wants to fix it please by all means\r\n    unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant)\r\n    x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method=\"multistep\", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final)\r\n\r\n    return x\r\n"
  },
  {
    "path": "modules/sd_schedulers.py",
    "content": "import dataclasses\nimport torch\nimport k_diffusion\nimport numpy as np\nfrom scipy import stats\n\nfrom modules import shared\n\n\ndef to_d(x, sigma, denoised):\n    \"\"\"Converts a denoiser output to a Karras ODE derivative.\"\"\"\n    return (x - denoised) / sigma\n\n\nk_diffusion.sampling.to_d = to_d\n\n\n@dataclasses.dataclass\nclass Scheduler:\n    name: str\n    label: str\n    function: any\n\n    default_rho: float = -1\n    need_inner_model: bool = False\n    aliases: list = None\n\n\ndef uniform(n, sigma_min, sigma_max, inner_model, device):\n    return inner_model.get_sigmas(n).to(device)\n\n\ndef sgm_uniform(n, sigma_min, sigma_max, inner_model, device):\n    start = inner_model.sigma_to_t(torch.tensor(sigma_max))\n    end = inner_model.sigma_to_t(torch.tensor(sigma_min))\n    sigs = [\n        inner_model.t_to_sigma(ts)\n        for ts in torch.linspace(start, end, n + 1)[:-1]\n    ]\n    sigs += [0.0]\n    return torch.FloatTensor(sigs).to(device)\n\n\ndef get_align_your_steps_sigmas(n, sigma_min, sigma_max, device):\n    # https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html\n    def loglinear_interp(t_steps, num_steps):\n        \"\"\"\n        Performs log-linear interpolation of a given array of decreasing numbers.\n        \"\"\"\n        xs = np.linspace(0, 1, len(t_steps))\n        ys = np.log(t_steps[::-1])\n\n        new_xs = np.linspace(0, 1, num_steps)\n        new_ys = np.interp(new_xs, xs, ys)\n\n        interped_ys = np.exp(new_ys)[::-1].copy()\n        return interped_ys\n\n    if shared.sd_model.is_sdxl:\n        sigmas = [14.615, 6.315, 3.771, 2.181, 1.342, 0.862, 0.555, 0.380, 0.234, 0.113, 0.029]\n    else:\n        # Default to SD 1.5 sigmas.\n        sigmas = [14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029]\n\n    if n != len(sigmas):\n        sigmas = np.append(loglinear_interp(sigmas, n), [0.0])\n    else:\n        sigmas.append(0.0)\n\n    return 
torch.FloatTensor(sigmas).to(device)\n\n\ndef kl_optimal(n, sigma_min, sigma_max, device):\n    alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))\n    alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))\n    step_indices = torch.arange(n + 1, device=device)\n    sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)\n    return sigmas\n\n\ndef simple_scheduler(n, sigma_min, sigma_max, inner_model, device):\n    sigs = []\n    ss = len(inner_model.sigmas) / n\n    for x in range(n):\n        sigs += [float(inner_model.sigmas[-(1 + int(x * ss))])]\n    sigs += [0.0]\n    return torch.FloatTensor(sigs).to(device)\n\n\ndef normal_scheduler(n, sigma_min, sigma_max, inner_model, device, sgm=False, floor=False):\n    start = inner_model.sigma_to_t(torch.tensor(sigma_max))\n    end = inner_model.sigma_to_t(torch.tensor(sigma_min))\n\n    if sgm:\n        timesteps = torch.linspace(start, end, n + 1)[:-1]\n    else:\n        timesteps = torch.linspace(start, end, n)\n\n    sigs = []\n    for x in range(len(timesteps)):\n        ts = timesteps[x]\n        sigs.append(inner_model.t_to_sigma(ts))\n    sigs += [0.0]\n    return torch.FloatTensor(sigs).to(device)\n\n\ndef ddim_scheduler(n, sigma_min, sigma_max, inner_model, device):\n    sigs = []\n    ss = max(len(inner_model.sigmas) // n, 1)\n    x = 1\n    while x < len(inner_model.sigmas):\n        sigs += [float(inner_model.sigmas[x])]\n        x += ss\n    sigs = sigs[::-1]\n    sigs += [0.0]\n    return torch.FloatTensor(sigs).to(device)\n\n\ndef beta_scheduler(n, sigma_min, sigma_max, inner_model, device):\n    # From \"Beta Sampling is All You Need\" [arXiv:2407.12173] (Lee et. 
al, 2024) \"\"\"\n    alpha = shared.opts.beta_dist_alpha\n    beta = shared.opts.beta_dist_beta\n    timesteps = 1 - np.linspace(0, 1, n)\n    timesteps = [stats.beta.ppf(x, alpha, beta) for x in timesteps]\n    sigmas = [sigma_min + (x * (sigma_max-sigma_min)) for x in timesteps]\n    sigmas += [0.0]\n    return torch.FloatTensor(sigmas).to(device)\n\n\nschedulers = [\n    Scheduler('automatic', 'Automatic', None),\n    Scheduler('uniform', 'Uniform', uniform, need_inner_model=True),\n    Scheduler('karras', 'Karras', k_diffusion.sampling.get_sigmas_karras, default_rho=7.0),\n    Scheduler('exponential', 'Exponential', k_diffusion.sampling.get_sigmas_exponential),\n    Scheduler('polyexponential', 'Polyexponential', k_diffusion.sampling.get_sigmas_polyexponential, default_rho=1.0),\n    Scheduler('sgm_uniform', 'SGM Uniform', sgm_uniform, need_inner_model=True, aliases=[\"SGMUniform\"]),\n    Scheduler('kl_optimal', 'KL Optimal', kl_optimal),\n    Scheduler('align_your_steps', 'Align Your Steps', get_align_your_steps_sigmas),\n    Scheduler('simple', 'Simple', simple_scheduler, need_inner_model=True),\n    Scheduler('normal', 'Normal', normal_scheduler, need_inner_model=True),\n    Scheduler('ddim', 'DDIM', ddim_scheduler, need_inner_model=True),\n    Scheduler('beta', 'Beta', beta_scheduler, need_inner_model=True),\n]\n\nschedulers_map = {**{x.name: x for x in schedulers}, **{x.label: x for x in schedulers}}\n"
  },
  {
    "path": "modules/sd_unet.py",
    "content": "import torch.nn\r\n\r\nfrom modules import script_callbacks, shared, devices\r\n\r\nunet_options = []\r\ncurrent_unet_option = None\r\ncurrent_unet = None\r\noriginal_forward = None  # not used, only left temporarily for compatibility\r\n\r\ndef list_unets():\r\n    new_unets = script_callbacks.list_unets_callback()\r\n\r\n    unet_options.clear()\r\n    unet_options.extend(new_unets)\r\n\r\n\r\ndef get_unet_option(option=None):\r\n    option = option or shared.opts.sd_unet\r\n\r\n    if option == \"None\":\r\n        return None\r\n\r\n    if option == \"Automatic\":\r\n        name = shared.sd_model.sd_checkpoint_info.model_name\r\n\r\n        options = [x for x in unet_options if x.model_name == name]\r\n\r\n        option = options[0].label if options else \"None\"\r\n\r\n    return next(iter([x for x in unet_options if x.label == option]), None)\r\n\r\n\r\ndef apply_unet(option=None):\r\n    global current_unet_option\r\n    global current_unet\r\n\r\n    new_option = get_unet_option(option)\r\n    if new_option == current_unet_option:\r\n        return\r\n\r\n    if current_unet is not None:\r\n        print(f\"Dectivating unet: {current_unet.option.label}\")\r\n        current_unet.deactivate()\r\n\r\n    current_unet_option = new_option\r\n    if current_unet_option is None:\r\n        current_unet = None\r\n\r\n        if not shared.sd_model.lowvram:\r\n            shared.sd_model.model.diffusion_model.to(devices.device)\r\n\r\n        return\r\n\r\n    shared.sd_model.model.diffusion_model.to(devices.cpu)\r\n    devices.torch_gc()\r\n\r\n    current_unet = current_unet_option.create_unet()\r\n    current_unet.option = current_unet_option\r\n    print(f\"Activating unet: {current_unet.option.label}\")\r\n    current_unet.activate()\r\n\r\n\r\nclass SdUnetOption:\r\n    model_name = None\r\n    \"\"\"name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this\"\"\"\r\n\r\n    label 
= None\r\n    \"\"\"name of the unet in UI\"\"\"\r\n\r\n    def create_unet(self):\r\n        \"\"\"returns SdUnet object to be used as a Unet instead of built-in unet when making pictures\"\"\"\r\n        raise NotImplementedError()\r\n\r\n\r\nclass SdUnet(torch.nn.Module):\r\n    def forward(self, x, timesteps, context, *args, **kwargs):\r\n        raise NotImplementedError()\r\n\r\n    def activate(self):\r\n        pass\r\n\r\n    def deactivate(self):\r\n        pass\r\n\r\n\r\ndef create_unet_forward(original_forward):\r\n    def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):\r\n        if current_unet is not None:\r\n            return current_unet.forward(x, timesteps, context, *args, **kwargs)\r\n\r\n        return original_forward(self, x, timesteps, context, *args, **kwargs)\r\n\r\n    return UNetModel_forward\r\n\r\n"
  },
  {
    "path": "modules/sd_vae.py",
    "content": "import os\nimport collections\nfrom dataclasses import dataclass\n\nfrom modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes\n\nimport glob\nfrom copy import deepcopy\n\n\nvae_path = os.path.abspath(os.path.join(paths.models_path, \"VAE\"))\nvae_ignore_keys = {\"model_ema.decay\", \"model_ema.num_updates\"}\nvae_dict = {}\n\n\nbase_vae = None\nloaded_vae_file = None\ncheckpoint_info = None\n\ncheckpoints_loaded = collections.OrderedDict()\n\n\ndef get_loaded_vae_name():\n    if loaded_vae_file is None:\n        return None\n\n    return os.path.basename(loaded_vae_file)\n\n\ndef get_loaded_vae_hash():\n    if loaded_vae_file is None:\n        return None\n\n    sha256 = hashes.sha256(loaded_vae_file, 'vae')\n\n    return sha256[0:10] if sha256 else None\n\n\ndef get_base_vae(model):\n    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:\n        return base_vae\n    return None\n\n\ndef store_base_vae(model):\n    global base_vae, checkpoint_info\n    if checkpoint_info != model.sd_checkpoint_info:\n        assert not loaded_vae_file, \"Trying to store non-base VAE!\"\n        base_vae = deepcopy(model.first_stage_model.state_dict())\n        checkpoint_info = model.sd_checkpoint_info\n\n\ndef delete_base_vae():\n    global base_vae, checkpoint_info\n    base_vae = None\n    checkpoint_info = None\n\n\ndef restore_base_vae(model):\n    global loaded_vae_file\n    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:\n        print(\"Restoring base VAE\")\n        _load_vae_dict(model, base_vae)\n        loaded_vae_file = None\n    delete_base_vae()\n\n\ndef get_filename(filepath):\n    return os.path.basename(filepath)\n\n\ndef refresh_vae_list():\n    vae_dict.clear()\n\n    paths = [\n        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),\n        os.path.join(sd_models.model_path, '**/*.vae.pt'),\n        
os.path.join(sd_models.model_path, '**/*.vae.safetensors'),\n        os.path.join(vae_path, '**/*.ckpt'),\n        os.path.join(vae_path, '**/*.pt'),\n        os.path.join(vae_path, '**/*.safetensors'),\n    ]\n\n    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):\n        paths += [\n            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),\n            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),\n            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),\n        ]\n\n    if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir):\n        paths += [\n            os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'),\n            os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'),\n            os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'),\n        ]\n\n    candidates = []\n    for path in paths:\n        candidates += glob.iglob(path, recursive=True)\n\n    for filepath in candidates:\n        name = get_filename(filepath)\n        vae_dict[name] = filepath\n\n    vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))))\n\n\ndef find_vae_near_checkpoint(checkpoint_file):\n    checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]\n    for vae_file in vae_dict.values():\n        if os.path.basename(vae_file).startswith(checkpoint_path):\n            return vae_file\n\n    return None\n\n\n@dataclass\nclass VaeResolution:\n    vae: str = None\n    source: str = None\n    resolved: bool = True\n\n    def tuple(self):\n        return self.vae, self.source\n\n\ndef is_automatic():\n    return shared.opts.sd_vae in {\"Automatic\", \"auto\"}  # \"auto\" for people with old config\n\n\ndef resolve_vae_from_setting() -> VaeResolution:\n    if shared.opts.sd_vae == \"None\":\n        return VaeResolution()\n\n    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)\n    if vae_from_options is not None:\n    
    return VaeResolution(vae_from_options, 'specified in settings')\n\n    if not is_automatic():\n        print(f\"Couldn't find VAE named {shared.opts.sd_vae}; using None instead\")\n\n    return VaeResolution(resolved=False)\n\n\ndef resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:\n    metadata = extra_networks.get_user_metadata(checkpoint_file)\n    vae_metadata = metadata.get(\"vae\", None)\n    if vae_metadata is not None and vae_metadata != \"Automatic\":\n        if vae_metadata == \"None\":\n            return VaeResolution()\n\n        vae_from_metadata = vae_dict.get(vae_metadata, None)\n        if vae_from_metadata is not None:\n            return VaeResolution(vae_from_metadata, \"from user metadata\")\n\n    return VaeResolution(resolved=False)\n\n\ndef resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:\n    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)\n    if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic()):\n        return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')\n\n    return VaeResolution(resolved=False)\n\n\ndef resolve_vae(checkpoint_file) -> VaeResolution:\n    if shared.cmd_opts.vae_path is not None:\n        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')\n\n    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():\n        return resolve_vae_from_setting()\n\n    res = resolve_vae_from_user_metadata(checkpoint_file)\n    if res.resolved:\n        return res\n\n    res = resolve_vae_near_checkpoint(checkpoint_file)\n    if res.resolved:\n        return res\n\n    res = resolve_vae_from_setting()\n\n    return res\n\n\ndef load_vae_dict(filename, map_location):\n    vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location)\n    vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != \"loss\" and k not in vae_ignore_keys}\n    return 
vae_dict_1\n\n\ndef load_vae(model, vae_file=None, vae_source=\"from unknown source\"):\n    global vae_dict, base_vae, loaded_vae_file\n    # save_settings = False\n\n    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0\n\n    if vae_file:\n        if cache_enabled and vae_file in checkpoints_loaded:\n            # use vae checkpoint cache\n            print(f\"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}\")\n            store_base_vae(model)\n            _load_vae_dict(model, checkpoints_loaded[vae_file])\n        else:\n            assert os.path.isfile(vae_file), f\"VAE {vae_source} doesn't exist: {vae_file}\"\n            print(f\"Loading VAE weights {vae_source}: {vae_file}\")\n            store_base_vae(model)\n\n            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)\n            _load_vae_dict(model, vae_dict_1)\n\n            if cache_enabled:\n                # cache newly loaded vae\n                checkpoints_loaded[vae_file] = vae_dict_1.copy()\n\n        # clean up cache if limit is reached\n        if cache_enabled:\n            while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model\n                checkpoints_loaded.popitem(last=False)  # LRU\n\n        # If vae used is not in dict, update it\n        # It will be removed on refresh though\n        vae_opt = get_filename(vae_file)\n        if vae_opt not in vae_dict:\n            vae_dict[vae_opt] = vae_file\n\n    elif loaded_vae_file:\n        restore_base_vae(model)\n\n    loaded_vae_file = vae_file\n    model.base_vae = base_vae\n    model.loaded_vae_file = loaded_vae_file\n\n\n# don't call this from outside\ndef _load_vae_dict(model, vae_dict_1):\n    model.first_stage_model.load_state_dict(vae_dict_1)\n    model.first_stage_model.to(devices.dtype_vae)\n\n\ndef clear_loaded_vae():\n    global loaded_vae_file\n    loaded_vae_file = None\n\n\nunspecified = object()\n\n\ndef 
reload_vae_weights(sd_model=None, vae_file=unspecified):\n    if not sd_model:\n        sd_model = shared.sd_model\n\n    checkpoint_info = sd_model.sd_checkpoint_info\n    checkpoint_file = checkpoint_info.filename\n\n    if vae_file == unspecified:\n        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()\n    else:\n        vae_source = \"from function argument\"\n\n    if loaded_vae_file == vae_file:\n        return\n\n    if sd_model.lowvram:\n        lowvram.send_everything_to_cpu()\n    else:\n        sd_model.to(devices.cpu)\n\n    sd_hijack.model_hijack.undo_hijack(sd_model)\n\n    load_vae(sd_model, vae_file, vae_source)\n\n    sd_hijack.model_hijack.hijack(sd_model)\n\n    if not sd_model.lowvram:\n        sd_model.to(devices.device)\n\n    script_callbacks.model_loaded_callback(sd_model)\n\n    print(\"VAE weights loaded.\")\n    return sd_model\n"
  },
  {
    "path": "modules/sd_vae_approx.py",
    "content": "import os\r\n\r\nimport torch\r\nfrom torch import nn\r\nfrom modules import devices, paths, shared\r\n\r\nsd_vae_approx_models = {}\r\n\r\n\r\nclass VAEApprox(nn.Module):\r\n    def __init__(self, latent_channels=4):\r\n        super(VAEApprox, self).__init__()\r\n        self.conv1 = nn.Conv2d(latent_channels, 8, (7, 7))\r\n        self.conv2 = nn.Conv2d(8, 16, (5, 5))\r\n        self.conv3 = nn.Conv2d(16, 32, (3, 3))\r\n        self.conv4 = nn.Conv2d(32, 64, (3, 3))\r\n        self.conv5 = nn.Conv2d(64, 32, (3, 3))\r\n        self.conv6 = nn.Conv2d(32, 16, (3, 3))\r\n        self.conv7 = nn.Conv2d(16, 8, (3, 3))\r\n        self.conv8 = nn.Conv2d(8, 3, (3, 3))\r\n\r\n    def forward(self, x):\r\n        extra = 11\r\n        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))\r\n        x = nn.functional.pad(x, (extra, extra, extra, extra))\r\n\r\n        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:\r\n            x = layer(x)\r\n            x = nn.functional.leaky_relu(x, 0.1)\r\n\r\n        return x\r\n\r\n\r\ndef download_model(model_path, model_url):\r\n    if not os.path.exists(model_path):\r\n        os.makedirs(os.path.dirname(model_path), exist_ok=True)\r\n\r\n        print(f'Downloading VAEApprox model to: {model_path}')\r\n        torch.hub.download_url_to_file(model_url, model_path)\r\n\r\n\r\ndef model():\r\n    if shared.sd_model.is_sd3:\r\n        model_name = \"vaeapprox-sd3.pt\"\r\n    elif shared.sd_model.is_sdxl:\r\n        model_name = \"vaeapprox-sdxl.pt\"\r\n    else:\r\n        model_name = \"model.pt\"\r\n\r\n    loaded_model = sd_vae_approx_models.get(model_name)\r\n\r\n    if loaded_model is None:\r\n        model_path = os.path.join(paths.models_path, \"VAE-approx\", model_name)\r\n        if not os.path.exists(model_path):\r\n            model_path = os.path.join(paths.script_path, \"models\", \"VAE-approx\", model_name)\r\n\r\n        if 
not os.path.exists(model_path):\r\n            model_path = os.path.join(paths.models_path, \"VAE-approx\", model_name)\r\n            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)\r\n\r\n        loaded_model = VAEApprox(latent_channels=shared.sd_model.latent_channels)\r\n        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))\r\n        loaded_model.eval()\r\n        loaded_model.to(devices.device, devices.dtype)\r\n        sd_vae_approx_models[model_name] = loaded_model\r\n\r\n    return loaded_model\r\n\r\n\r\ndef cheap_approximation(sample):\r\n    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2\r\n\r\n    if shared.sd_model.is_sd3:\r\n        coeffs = [\r\n            [-0.0645,  0.0177,  0.1052], [ 0.0028,  0.0312,  0.0650],\r\n            [ 0.1848,  0.0762,  0.0360], [ 0.0944,  0.0360,  0.0889],\r\n            [ 0.0897,  0.0506, -0.0364], [-0.0020,  0.1203,  0.0284],\r\n            [ 0.0855,  0.0118,  0.0283], [-0.0539,  0.0658,  0.1047],\r\n            [-0.0057,  0.0116,  0.0700], [-0.0412,  0.0281, -0.0039],\r\n            [ 0.1106,  0.1171,  0.1220], [-0.0248,  0.0682, -0.0481],\r\n            [ 0.0815,  0.0846,  0.1207], [-0.0120, -0.0055, -0.0867],\r\n            [-0.0749, -0.0634, -0.0456], [-0.1418, -0.1457, -0.1259],\r\n        ]\r\n    elif shared.sd_model.is_sdxl:\r\n        coeffs = [\r\n            [ 0.3448,  0.4168,  0.4395],\r\n            [-0.1953, -0.0290,  0.0250],\r\n            [ 0.1074,  0.0886, -0.0163],\r\n            [-0.3730, -0.2499, -0.2088],\r\n        ]\r\n    else:\r\n        coeffs = [\r\n            [ 0.298,  0.207,  0.208],\r\n            [ 0.187,  0.286,  0.173],\r\n            [-0.158,  0.189,  0.264],\r\n            [-0.184, -0.271, -0.473],\r\n        ]\r\n\r\n    coefs = torch.tensor(coeffs).to(sample.device)\r\n\r\n    x_sample = 
torch.einsum(\"...lxy,lr -> ...rxy\", sample, coefs)\r\n\r\n    return x_sample\r\n"
  },
  {
    "path": "modules/sd_vae_taesd.py",
    "content": "\"\"\"\nTiny AutoEncoder for Stable Diffusion\n(DNN for encoding / decoding SD's latent space)\n\nhttps://github.com/madebyollin/taesd\n\"\"\"\nimport os\nimport torch\nimport torch.nn as nn\n\nfrom modules import devices, paths_internal, shared\n\nsd_vae_taesd_models = {}\n\n\ndef conv(n_in, n_out, **kwargs):\n    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)\n\n\nclass Clamp(nn.Module):\n    @staticmethod\n    def forward(x):\n        return torch.tanh(x / 3) * 3\n\n\nclass Block(nn.Module):\n    def __init__(self, n_in, n_out):\n        super().__init__()\n        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))\n        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()\n        self.fuse = nn.ReLU()\n\n    def forward(self, x):\n        return self.fuse(self.conv(x) + self.skip(x))\n\n\ndef decoder(latent_channels=4):\n    return nn.Sequential(\n        Clamp(), conv(latent_channels, 64), nn.ReLU(),\n        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),\n        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),\n        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),\n        Block(64, 64), conv(64, 3),\n    )\n\n\ndef encoder(latent_channels=4):\n    return nn.Sequential(\n        conv(3, 64), Block(64, 64),\n        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),\n        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),\n        conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),\n        conv(64, latent_channels),\n    )\n\n\nclass TAESDDecoder(nn.Module):\n    latent_magnitude = 3\n    latent_shift = 0.5\n\n    def __init__(self, decoder_path=\"taesd_decoder.pth\", latent_channels=None):\n        
\"\"\"Initialize pretrained TAESD on the given device from the given checkpoints.\"\"\"\n        super().__init__()\n\n        if latent_channels is None:\n            latent_channels = 16 if \"taesd3\" in str(decoder_path) else 4\n\n        self.decoder = decoder(latent_channels)\n        self.decoder.load_state_dict(\n            torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))\n\n\nclass TAESDEncoder(nn.Module):\n    latent_magnitude = 3\n    latent_shift = 0.5\n\n    def __init__(self, encoder_path=\"taesd_encoder.pth\", latent_channels=None):\n        \"\"\"Initialize pretrained TAESD on the given device from the given checkpoints.\"\"\"\n        super().__init__()\n\n        if latent_channels is None:\n            latent_channels = 16 if \"taesd3\" in str(encoder_path) else 4\n\n        self.encoder = encoder(latent_channels)\n        self.encoder.load_state_dict(\n            torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))\n\n\ndef download_model(model_path, model_url):\n    if not os.path.exists(model_path):\n        os.makedirs(os.path.dirname(model_path), exist_ok=True)\n\n        print(f'Downloading TAESD model to: {model_path}')\n        torch.hub.download_url_to_file(model_url, model_path)\n\n\ndef decoder_model():\n    if shared.sd_model.is_sd3:\n        model_name = \"taesd3_decoder.pth\"\n    elif shared.sd_model.is_sdxl:\n        model_name = \"taesdxl_decoder.pth\"\n    else:\n        model_name = \"taesd_decoder.pth\"\n\n    loaded_model = sd_vae_taesd_models.get(model_name)\n\n    if loaded_model is None:\n        model_path = os.path.join(paths_internal.models_path, \"VAE-taesd\", model_name)\n        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)\n\n        if os.path.exists(model_path):\n            loaded_model = TAESDDecoder(model_path)\n            loaded_model.eval()\n            loaded_model.to(devices.device, 
devices.dtype)\n            sd_vae_taesd_models[model_name] = loaded_model\n        else:\n            raise FileNotFoundError('TAESD model not found')\n\n    return loaded_model.decoder\n\n\ndef encoder_model():\n    if shared.sd_model.is_sd3:\n        model_name = \"taesd3_encoder.pth\"\n    elif shared.sd_model.is_sdxl:\n        model_name = \"taesdxl_encoder.pth\"\n    else:\n        model_name = \"taesd_encoder.pth\"\n\n    loaded_model = sd_vae_taesd_models.get(model_name)\n\n    if loaded_model is None:\n        model_path = os.path.join(paths_internal.models_path, \"VAE-taesd\", model_name)\n        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)\n\n        if os.path.exists(model_path):\n            loaded_model = TAESDEncoder(model_path)\n            loaded_model.eval()\n            loaded_model.to(devices.device, devices.dtype)\n            sd_vae_taesd_models[model_name] = loaded_model\n        else:\n            raise FileNotFoundError('TAESD model not found')\n\n    return loaded_model.encoder\n"
  },
  {
    "path": "modules/shared.py",
    "content": "import os\r\nimport sys\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types\r\nfrom modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401\r\nfrom modules import util\r\nfrom typing import TYPE_CHECKING\r\n\r\nif TYPE_CHECKING:\r\n    from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon\r\n\r\ncmd_opts = shared_cmd_options.cmd_opts\r\nparser = shared_cmd_options.parser\r\n\r\nbatch_cond_uncond = True  # old field, unused now in favor of shared.opts.batch_cond_uncond\r\nparallel_processing_allowed = True\r\nstyles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]\r\nconfig_filename = cmd_opts.ui_settings_file\r\nhide_dirs = {\"visible\": not cmd_opts.hide_ui_dir_config}\r\n\r\ndemo: gr.Blocks = None\r\n\r\ndevice: str = None\r\n\r\nweight_load_location: str = None\r\n\r\nxformers_available = False\r\n\r\nhypernetworks = {}\r\n\r\nloaded_hypernetworks = []\r\n\r\nstate: 'shared_state.State' = None\r\n\r\nprompt_styles: 'styles.StyleDatabase' = None\r\n\r\ninterrogator: 'interrogate.InterrogateModels' = None\r\n\r\nface_restorers = []\r\n\r\noptions_templates: dict = None\r\nopts: options.Options = None\r\nrestricted_opts: set[str] = None\r\n\r\nsd_model: sd_models_types.WebuiSdModel = None\r\n\r\nsettings_components: dict = None\r\n\"\"\"assigned from ui.py, a mapping on setting names to gradio components responsible for those settings\"\"\"\r\n\r\ntab_names = []\r\n\r\nlatent_upscale_default_mode = \"Latent\"\r\nlatent_upscale_modes = {\r\n    \"Latent\": {\"mode\": \"bilinear\", \"antialias\": False},\r\n    \"Latent (antialiased)\": {\"mode\": \"bilinear\", \"antialias\": True},\r\n    \"Latent (bicubic)\": {\"mode\": 
\"bicubic\", \"antialias\": False},\r\n    \"Latent (bicubic antialiased)\": {\"mode\": \"bicubic\", \"antialias\": True},\r\n    \"Latent (nearest)\": {\"mode\": \"nearest\", \"antialias\": False},\r\n    \"Latent (nearest-exact)\": {\"mode\": \"nearest-exact\", \"antialias\": False},\r\n}\r\n\r\nsd_upscalers = []\r\n\r\nclip_model = None\r\n\r\nprogress_print_out = sys.stdout\r\n\r\ngradio_theme = gr.themes.Base()\r\n\r\ntotal_tqdm: 'shared_total_tqdm.TotalTQDM' = None\r\n\r\nmem_mon: 'memmon.MemUsageMonitor' = None\r\n\r\noptions_section = options.options_section\r\nOptionInfo = options.OptionInfo\r\nOptionHTML = options.OptionHTML\r\n\r\nnatural_sort_key = util.natural_sort_key\r\nlistfiles = util.listfiles\r\nhtml_path = util.html_path\r\nhtml = util.html\r\nwalk_files = util.walk_files\r\nldm_print = util.ldm_print\r\n\r\nreload_gradio_theme = shared_gradio_themes.reload_gradio_theme\r\n\r\nlist_checkpoint_tiles = shared_items.list_checkpoint_tiles\r\nrefresh_checkpoints = shared_items.refresh_checkpoints\r\nlist_samplers = shared_items.list_samplers\r\nreload_hypernetworks = shared_items.reload_hypernetworks\r\n\r\nhf_endpoint = os.getenv('HF_ENDPOINT', 'https://huggingface.co')\r\n"
  },
  {
    "path": "modules/shared_cmd_options.py",
    "content": "import os\r\n\r\nimport launch\r\nfrom modules import cmd_args, script_loading\r\nfrom modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir  # noqa: F401\r\n\r\nparser = cmd_args.parser\r\n\r\nscript_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))\r\nscript_loading.preload_extensions(extensions_builtin_dir, parser)\r\n\r\nif os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:\r\n    cmd_opts = parser.parse_args()\r\nelse:\r\n    cmd_opts, _ = parser.parse_known_args()\r\n\r\ncmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])\r\ncmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access\r\n"
  },
  {
    "path": "modules/shared_gradio_themes.py",
    "content": "import os\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import errors, shared\r\nfrom modules.paths_internal import script_path\r\n\r\n\r\n# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json\r\ngradio_hf_hub_themes = [\r\n    \"gradio/base\",\r\n    \"gradio/glass\",\r\n    \"gradio/monochrome\",\r\n    \"gradio/seafoam\",\r\n    \"gradio/soft\",\r\n    \"gradio/dracula_test\",\r\n    \"abidlabs/dracula_test\",\r\n    \"abidlabs/Lime\",\r\n    \"abidlabs/pakistan\",\r\n    \"Ama434/neutral-barlow\",\r\n    \"dawood/microsoft_windows\",\r\n    \"finlaymacklon/smooth_slate\",\r\n    \"Franklisi/darkmode\",\r\n    \"freddyaboulton/dracula_revamped\",\r\n    \"freddyaboulton/test-blue\",\r\n    \"gstaff/xkcd\",\r\n    \"Insuz/Mocha\",\r\n    \"Insuz/SimpleIndigo\",\r\n    \"JohnSmith9982/small_and_pretty\",\r\n    \"nota-ai/theme\",\r\n    \"nuttea/Softblue\",\r\n    \"ParityError/Anime\",\r\n    \"reilnuud/polite\",\r\n    \"remilia/Ghostly\",\r\n    \"rottenlittlecreature/Moon_Goblin\",\r\n    \"step-3-profit/Midnight-Deep\",\r\n    \"Taithrah/Minimal\",\r\n    \"ysharma/huggingface\",\r\n    \"ysharma/steampunk\",\r\n    \"NoCrypt/miku\"\r\n]\r\n\r\n\r\ndef reload_gradio_theme(theme_name=None):\r\n    if not theme_name:\r\n        theme_name = shared.opts.gradio_theme\r\n\r\n    default_theme_args = dict(\r\n        font=[\"Source Sans Pro\", 'ui-sans-serif', 'system-ui', 'sans-serif'],\r\n        font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],\r\n    )\r\n\r\n    if theme_name == \"Default\":\r\n        shared.gradio_theme = gr.themes.Default(**default_theme_args)\r\n    else:\r\n        try:\r\n            theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')\r\n            theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace(\"/\", \"_\")}.json')\r\n            if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):\r\n         
       shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)\r\n            else:\r\n                os.makedirs(theme_cache_dir, exist_ok=True)\r\n                shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)\r\n                shared.gradio_theme.dump(theme_cache_path)\r\n        except Exception as e:\r\n            errors.display(e, \"changing gradio theme\")\r\n            shared.gradio_theme = gr.themes.Default(**default_theme_args)\r\n\r\n    # append additional values gradio_theme\r\n    shared.gradio_theme.sd_webui_modal_lightbox_toolbar_opacity = shared.opts.sd_webui_modal_lightbox_toolbar_opacity\r\n    shared.gradio_theme.sd_webui_modal_lightbox_icon_opacity = shared.opts.sd_webui_modal_lightbox_icon_opacity\r\n\r\n\r\ndef resolve_var(name: str, gradio_theme=None, history=None):\r\n    \"\"\"\r\n    Attempt to resolve a theme variable name to its value\r\n\r\n    Parameters:\r\n        name (str): The name of the theme variable\r\n            ie \"background_fill_primary\", \"background_fill_primary_dark\"\r\n            spaces and asterisk (*) prefix is removed from name before lookup\r\n        gradio_theme (gradio.themes.ThemeClass): The theme object to resolve the variable from\r\n            blank to use the webui default shared.gradio_theme\r\n        history (list): A list of previously resolved variables to prevent circular references\r\n            for regular use leave blank\r\n    Returns:\r\n        str: The resolved value\r\n\r\n    Error handling:\r\n        return either #000000 or #ffffff depending on initial name ending with \"_dark\"\r\n    \"\"\"\r\n    try:\r\n        if history is None:\r\n            history = []\r\n        if gradio_theme is None:\r\n            gradio_theme = shared.gradio_theme\r\n\r\n        name = name.strip()\r\n        name = name[1:] if name.startswith(\"*\") else name\r\n\r\n        if name in history:\r\n            raise ValueError(f'Circular references: name \"{name}\" in 
{history}')\r\n\r\n        if value := getattr(gradio_theme, name, None):\r\n            return resolve_var(value, gradio_theme, history + [name])\r\n        else:\r\n            return name\r\n\r\n    except Exception:\r\n        name = history[0] if history else name\r\n        errors.report(f'resolve_color({name})', exc_info=True)\r\n        return '#000000' if name.endswith(\"_dark\") else '#ffffff'\r\n"
  },
  {
    "path": "modules/shared_init.py",
    "content": "import os\r\n\r\nimport torch\r\n\r\nfrom modules import shared\r\nfrom modules.shared import cmd_opts\r\n\r\n\r\ndef initialize():\r\n    \"\"\"Initializes fields inside the shared module in a controlled manner.\r\n\r\n    Should be called early because some other modules you can import mingt need these fields to be already set.\r\n    \"\"\"\r\n\r\n    os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)\r\n\r\n    from modules import options, shared_options\r\n    shared.options_templates = shared_options.options_templates\r\n    shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)\r\n    shared.restricted_opts = shared_options.restricted_opts\r\n    try:\r\n        shared.opts.load(shared.config_filename)\r\n    except FileNotFoundError:\r\n        pass\r\n\r\n    from modules import devices\r\n    devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \\\r\n        (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])\r\n\r\n    devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16\r\n    devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16\r\n    devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype\r\n\r\n    if cmd_opts.precision == \"half\":\r\n        msg = \"--no-half and --no-half-vae conflict with --precision half\"\r\n        assert devices.dtype == torch.float16, msg\r\n        assert devices.dtype_vae == torch.float16, msg\r\n        assert devices.dtype_inference == torch.float16, msg\r\n        devices.force_fp16 = True\r\n        devices.force_model_fp16()\r\n\r\n    shared.device = devices.device\r\n    shared.weight_load_location = None if cmd_opts.lowram else \"cpu\"\r\n\r\n    from modules import shared_state\r\n    
shared.state = shared_state.State()\r\n\r\n    from modules import styles\r\n    shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)\r\n\r\n    from modules import interrogate\r\n    shared.interrogator = interrogate.InterrogateModels(\"interrogate\")\r\n\r\n    from modules import shared_total_tqdm\r\n    shared.total_tqdm = shared_total_tqdm.TotalTQDM()\r\n\r\n    from modules import memmon, devices\r\n    shared.mem_mon = memmon.MemUsageMonitor(\"MemMon\", devices.device, shared.opts)\r\n    shared.mem_mon.start()\r\n\r\n"
  },
  {
    "path": "modules/shared_items.py",
    "content": "import html\r\nimport sys\r\n\r\nfrom modules import script_callbacks, scripts, ui_components\r\nfrom modules.options import OptionHTML, OptionInfo\r\nfrom modules.shared_cmd_options import cmd_opts\r\n\r\n\r\ndef realesrgan_models_names():\r\n    import modules.realesrgan_model\r\n    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]\r\n\r\n\r\ndef dat_models_names():\r\n    import modules.dat_model\r\n    return [x.name for x in modules.dat_model.get_dat_models(None)]\r\n\r\n\r\ndef postprocessing_scripts():\r\n    import modules.scripts\r\n\r\n    return modules.scripts.scripts_postproc.scripts\r\n\r\n\r\ndef sd_vae_items():\r\n    import modules.sd_vae\r\n\r\n    return [\"Automatic\", \"None\"] + list(modules.sd_vae.vae_dict)\r\n\r\n\r\ndef refresh_vae_list():\r\n    import modules.sd_vae\r\n\r\n    modules.sd_vae.refresh_vae_list()\r\n\r\n\r\ndef cross_attention_optimizations():\r\n    import modules.sd_hijack\r\n\r\n    return [\"Automatic\"] + [x.title() for x in modules.sd_hijack.optimizers] + [\"None\"]\r\n\r\n\r\ndef sd_unet_items():\r\n    import modules.sd_unet\r\n\r\n    return [\"Automatic\"] + [x.label for x in modules.sd_unet.unet_options] + [\"None\"]\r\n\r\n\r\ndef refresh_unet_list():\r\n    import modules.sd_unet\r\n\r\n    modules.sd_unet.list_unets()\r\n\r\n\r\ndef list_checkpoint_tiles(use_short=False):\r\n    import modules.sd_models\r\n    return modules.sd_models.checkpoint_tiles(use_short)\r\n\r\n\r\ndef refresh_checkpoints():\r\n    import modules.sd_models\r\n    return modules.sd_models.list_models()\r\n\r\n\r\ndef list_samplers():\r\n    import modules.sd_samplers\r\n    return modules.sd_samplers.all_samplers\r\n\r\n\r\ndef reload_hypernetworks():\r\n    from modules.hypernetworks import hypernetwork\r\n    from modules import shared\r\n\r\n    shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)\r\n\r\n\r\ndef get_infotext_names():\r\n    from modules import 
infotext_utils, shared\r\n    res = {}\r\n\r\n    for info in shared.opts.data_labels.values():\r\n        if info.infotext:\r\n            res[info.infotext] = 1\r\n\r\n    for tab_data in infotext_utils.paste_fields.values():\r\n        for _, name in tab_data.get(\"fields\") or []:\r\n            if isinstance(name, str):\r\n                res[name] = 1\r\n\r\n    return list(res)\r\n\r\n\r\nui_reorder_categories_builtin_items = [\r\n    \"prompt\",\r\n    \"image\",\r\n    \"inpaint\",\r\n    \"sampler\",\r\n    \"accordions\",\r\n    \"checkboxes\",\r\n    \"dimensions\",\r\n    \"cfg\",\r\n    \"denoising\",\r\n    \"seed\",\r\n    \"batch\",\r\n    \"override_settings\",\r\n]\r\n\r\n\r\ndef ui_reorder_categories():\r\n    from modules import scripts\r\n\r\n    yield from ui_reorder_categories_builtin_items\r\n\r\n    sections = {}\r\n    for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts:\r\n        if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items:\r\n            sections[script.section] = 1\r\n\r\n    yield from sections\r\n\r\n    yield \"scripts\"\r\n\r\n\r\ndef callbacks_order_settings():\r\n    options = {\r\n        \"sd_vae_explanation\": OptionHTML(\"\"\"\r\n    For categories below, callbacks added to dropdowns happen before others, in order listed.\r\n    \"\"\"),\r\n\r\n    }\r\n\r\n    callback_options = {}\r\n\r\n    for category, _ in script_callbacks.enumerate_callbacks():\r\n        callback_options[category] = script_callbacks.ordered_callbacks(category, enable_user_sort=False)\r\n\r\n    for method_name in scripts.scripts_txt2img.callback_names:\r\n        callback_options[\"script_\" + method_name] = scripts.scripts_txt2img.create_ordered_callbacks_list(method_name, enable_user_sort=False)\r\n\r\n    for method_name in scripts.scripts_img2img.callback_names:\r\n        callbacks = callback_options.get(\"script_\" + method_name, [])\r\n\r\n        for addition in 
scripts.scripts_img2img.create_ordered_callbacks_list(method_name, enable_user_sort=False):\r\n            if any(x.name == addition.name for x in callbacks):\r\n                continue\r\n\r\n            callbacks.append(addition)\r\n\r\n        callback_options[\"script_\" + method_name] = callbacks\r\n\r\n    for category, callbacks in callback_options.items():\r\n        if not callbacks:\r\n            continue\r\n\r\n        option_info = OptionInfo([], f\"{category} callback priority\", ui_components.DropdownMulti, {\"choices\": [x.name for x in callbacks]})\r\n        option_info.needs_restart()\r\n        option_info.html(\"<div class='info'>Default order: <ol>\" + \"\".join(f\"<li>{html.escape(x.name)}</li>\\n\" for x in callbacks) + \"</ol></div>\")\r\n        options['prioritized_callbacks_' + category] = option_info\r\n\r\n    return options\r\n\r\n\r\nclass Shared(sys.modules[__name__].__class__):\r\n    \"\"\"\r\n    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than\r\n    at program startup.\r\n    \"\"\"\r\n\r\n    sd_model_val = None\r\n\r\n    @property\r\n    def sd_model(self):\r\n        import modules.sd_models\r\n\r\n        return modules.sd_models.model_data.get_sd_model()\r\n\r\n    @sd_model.setter\r\n    def sd_model(self, value):\r\n        import modules.sd_models\r\n\r\n        modules.sd_models.model_data.set_sd_model(value)\r\n\r\n\r\nsys.modules['modules.shared'].__class__ = Shared\r\n"
  },
  {
    "path": "modules/shared_options.py",
    "content": "import os\r\nimport gradio as gr\r\n\r\nfrom modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes, util, sd_emphasis\r\nfrom modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir, default_output_dir  # noqa: F401\r\nfrom modules.shared_cmd_options import cmd_opts\r\nfrom modules.options import options_section, OptionInfo, OptionHTML, categories\r\n\r\noptions_templates = {}\r\nhide_dirs = shared.hide_dirs\r\n\r\nrestricted_opts = {\r\n    \"samples_filename_pattern\",\r\n    \"directories_filename_pattern\",\r\n    \"outdir_samples\",\r\n    \"outdir_txt2img_samples\",\r\n    \"outdir_img2img_samples\",\r\n    \"outdir_extras_samples\",\r\n    \"outdir_grids\",\r\n    \"outdir_txt2img_grids\",\r\n    \"outdir_save\",\r\n    \"outdir_init_images\",\r\n    \"temp_dir\",\r\n    \"clean_temp_dir_at_start\",\r\n}\r\n\r\ncategories.register_category(\"saving\", \"Saving images\")\r\ncategories.register_category(\"sd\", \"Stable Diffusion\")\r\ncategories.register_category(\"ui\", \"User Interface\")\r\ncategories.register_category(\"system\", \"System\")\r\ncategories.register_category(\"postprocessing\", \"Postprocessing\")\r\ncategories.register_category(\"training\", \"Training\")\r\n\r\noptions_templates.update(options_section(('saving-images', \"Saving images/grids\", \"saving\"), {\r\n    \"samples_save\": OptionInfo(True, \"Always save all generated images\"),\r\n    \"samples_format\": OptionInfo('png', 'File format for images'),\r\n    \"samples_filename_pattern\": OptionInfo(\"\", \"Images filename pattern\", component_args=hide_dirs).link(\"wiki\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory\"),\r\n    \"save_images_add_number\": OptionInfo(True, \"Add number to filename when saving\", 
component_args=hide_dirs),\r\n    \"save_images_replace_action\": OptionInfo(\"Replace\", \"Saving the image to an existing file\", gr.Radio, {\"choices\": [\"Replace\", \"Add number suffix\"], **hide_dirs}),\r\n    \"grid_save\": OptionInfo(True, \"Always save all generated image grids\"),\r\n    \"grid_format\": OptionInfo('png', 'File format for grids'),\r\n    \"grid_extended_filename\": OptionInfo(False, \"Add extended info (seed, prompt) to filename when saving grid\"),\r\n    \"grid_only_if_multiple\": OptionInfo(True, \"Do not save grids consisting of one picture\"),\r\n    \"grid_prevent_empty_spots\": OptionInfo(False, \"Prevent empty spots in grid (when set to autodetect)\"),\r\n    \"grid_zip_filename_pattern\": OptionInfo(\"\", \"Archive filename pattern\", component_args=hide_dirs).link(\"wiki\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory\"),\r\n    \"n_rows\": OptionInfo(-1, \"Grid row count; use -1 for autodetect and 0 for it to be same as batch size\", gr.Slider, {\"minimum\": -1, \"maximum\": 16, \"step\": 1}),\r\n    \"font\": OptionInfo(\"\", \"Font for image grids that have text\"),\r\n    \"grid_text_active_color\": OptionInfo(\"#000000\", \"Text color for image grids\", ui_components.FormColorPicker, {}),\r\n    \"grid_text_inactive_color\": OptionInfo(\"#999999\", \"Inactive text color for image grids\", ui_components.FormColorPicker, {}),\r\n    \"grid_background_color\": OptionInfo(\"#ffffff\", \"Background color for image grids\", ui_components.FormColorPicker, {}),\r\n\r\n    \"save_images_before_face_restoration\": OptionInfo(False, \"Save a copy of image before doing face restoration.\"),\r\n    \"save_images_before_highres_fix\": OptionInfo(False, \"Save a copy of image before applying highres fix.\"),\r\n    \"save_images_before_color_correction\": OptionInfo(False, \"Save a copy of image before applying color correction to img2img results\"),\r\n    \"save_mask\": 
OptionInfo(False, \"For inpainting, save a copy of the greyscale mask\"),\r\n    \"save_mask_composite\": OptionInfo(False, \"For inpainting, save a masked composite\"),\r\n    \"jpeg_quality\": OptionInfo(80, \"Quality for saved jpeg and avif images\", gr.Slider, {\"minimum\": 1, \"maximum\": 100, \"step\": 1}),\r\n    \"webp_lossless\": OptionInfo(False, \"Use lossless compression for webp images\"),\r\n    \"export_for_4chan\": OptionInfo(True, \"Save copy of large images as JPG\").info(\"if the file size is above the limit, or either width or height are above the limit\"),\r\n    \"img_downscale_threshold\": OptionInfo(4.0, \"File size limit for the above option, MB\", gr.Number),\r\n    \"target_side_length\": OptionInfo(4000, \"Width/height limit for the above option, in pixels\", gr.Number),\r\n    \"img_max_size_mp\": OptionInfo(200, \"Maximum image size\", gr.Number).info(\"in megapixels\"),\r\n\r\n    \"use_original_name_batch\": OptionInfo(True, \"Use original name for output filename during batch process in extras tab\"),\r\n    \"use_upscaler_name_as_suffix\": OptionInfo(False, \"Use upscaler name as filename suffix in the extras tab\"),\r\n    \"save_selected_only\": OptionInfo(True, \"When using 'Save' button, only save a single selected image\"),\r\n    \"save_write_log_csv\": OptionInfo(True, \"Write log.csv when saving images using 'Save' button\"),\r\n    \"save_init_img\": OptionInfo(False, \"Save init images when using img2img\"),\r\n\r\n    \"temp_dir\":  OptionInfo(\"\", \"Directory for temporary images; leave empty for default\"),\r\n    \"clean_temp_dir_at_start\": OptionInfo(False, \"Cleanup non-default temporary directory when starting webui\"),\r\n\r\n    \"save_incomplete_images\": OptionInfo(False, \"Save incomplete images\").info(\"save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output.\"),\r\n\r\n    \"notification_audio\": OptionInfo(True, \"Play notification sound after 
image generation\").info(\"notification.mp3 should be present in the root directory\").needs_reload_ui(),\r\n    \"notification_volume\": OptionInfo(100, \"Notification sound volume\", gr.Slider, {\"minimum\": 0, \"maximum\": 100, \"step\": 1}).info(\"in %\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('saving-paths', \"Paths for saving\", \"saving\"), {\r\n    \"outdir_samples\": OptionInfo(\"\", \"Output directory for images; if empty, defaults to three directories below\", component_args=hide_dirs),\r\n    \"outdir_txt2img_samples\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-images')), 'Output directory for txt2img images', component_args=hide_dirs),\r\n    \"outdir_img2img_samples\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-images')), 'Output directory for img2img images', component_args=hide_dirs),\r\n    \"outdir_extras_samples\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'extras-images')), 'Output directory for images from extras tab', component_args=hide_dirs),\r\n    \"outdir_grids\": OptionInfo(\"\", \"Output directory for grids; if empty, defaults to two directories below\", component_args=hide_dirs),\r\n    \"outdir_txt2img_grids\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'txt2img-grids')), 'Output directory for txt2img grids', component_args=hide_dirs),\r\n    \"outdir_img2img_grids\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'img2img-grids')), 'Output directory for img2img grids', component_args=hide_dirs),\r\n    \"outdir_save\": OptionInfo(util.truncate_path(os.path.join(data_path, 'log', 'images')), \"Directory for saving images using the Save button\", component_args=hide_dirs),\r\n    \"outdir_init_images\": OptionInfo(util.truncate_path(os.path.join(default_output_dir, 'init-images')), \"Directory for saving init images when using img2img\", 
component_args=hide_dirs),\r\n}))\r\n\r\noptions_templates.update(options_section(('saving-to-dirs', \"Saving to a directory\", \"saving\"), {\r\n    \"save_to_dirs\": OptionInfo(True, \"Save images to a subdirectory\"),\r\n    \"grid_save_to_dirs\": OptionInfo(True, \"Save grids to a subdirectory\"),\r\n    \"use_save_to_dirs_for_ui\": OptionInfo(False, \"When using \\\"Save\\\" button, save images to a subdirectory\"),\r\n    \"directories_filename_pattern\": OptionInfo(\"[date]\", \"Directory name pattern\", component_args=hide_dirs).link(\"wiki\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory\"),\r\n    \"directories_max_prompt_words\": OptionInfo(8, \"Max prompt words for [prompt_words] pattern\", gr.Slider, {\"minimum\": 1, \"maximum\": 20, \"step\": 1, **hide_dirs}),\r\n}))\r\n\r\noptions_templates.update(options_section(('upscaling', \"Upscaling\", \"postprocessing\"), {\r\n    \"ESRGAN_tile\": OptionInfo(192, \"Tile size for ESRGAN upscalers.\", gr.Slider, {\"minimum\": 0, \"maximum\": 512, \"step\": 16}).info(\"0 = no tiling\"),\r\n    \"ESRGAN_tile_overlap\": OptionInfo(8, \"Tile overlap for ESRGAN upscalers.\", gr.Slider, {\"minimum\": 0, \"maximum\": 48, \"step\": 1}).info(\"Low values = visible seam\"),\r\n    \"realesrgan_enabled_models\": OptionInfo([\"R-ESRGAN 4x+\", \"R-ESRGAN 4x+ Anime6B\"], \"Select which Real-ESRGAN models to show in the web UI.\", gr.CheckboxGroup, lambda: {\"choices\": shared_items.realesrgan_models_names()}),\r\n    \"dat_enabled_models\": OptionInfo([\"DAT x2\", \"DAT x3\", \"DAT x4\"], \"Select which DAT models to show in the web UI.\", gr.CheckboxGroup, lambda: {\"choices\": shared_items.dat_models_names()}),\r\n    \"DAT_tile\": OptionInfo(192, \"Tile size for DAT upscalers.\", gr.Slider, {\"minimum\": 0, \"maximum\": 512, \"step\": 16}).info(\"0 = no tiling\"),\r\n    \"DAT_tile_overlap\": OptionInfo(8, \"Tile overlap for DAT upscalers.\", gr.Slider, 
{\"minimum\": 0, \"maximum\": 48, \"step\": 1}).info(\"Low values = visible seam\"),\r\n    \"upscaler_for_img2img\": OptionInfo(None, \"Upscaler for img2img\", gr.Dropdown, lambda: {\"choices\": [x.name for x in shared.sd_upscalers]}),\r\n    \"set_scale_by_when_changing_upscaler\": OptionInfo(False, \"Automatically set the Scale by factor based on the name of the selected Upscaler.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('face-restoration', \"Face restoration\", \"postprocessing\"), {\r\n    \"face_restoration\": OptionInfo(False, \"Restore faces\", infotext='Face restoration').info(\"will use a third-party model on generation result to reconstruct faces\"),\r\n    \"face_restoration_model\": OptionInfo(\"CodeFormer\", \"Face restoration model\", gr.Radio, lambda: {\"choices\": [x.name() for x in shared.face_restorers]}),\r\n    \"code_former_weight\": OptionInfo(0.5, \"CodeFormer weight\", gr.Slider, {\"minimum\": 0, \"maximum\": 1, \"step\": 0.01}).info(\"0 = maximum effect; 1 = minimum effect\"),\r\n    \"face_restoration_unload\": OptionInfo(False, \"Move face restoration model from VRAM into RAM after processing\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('system', \"System\", \"system\"), {\r\n    \"auto_launch_browser\": OptionInfo(\"Local\", \"Automatically open webui in browser on startup\", gr.Radio, lambda: {\"choices\": [\"Disable\", \"Local\", \"Remote\"]}),\r\n    \"enable_console_prompts\": OptionInfo(shared.cmd_opts.enable_console_prompts, \"Print prompts to console when generating with txt2img and img2img.\"),\r\n    \"show_warnings\": OptionInfo(False, \"Show warnings in console.\").needs_reload_ui(),\r\n    \"show_gradio_deprecation_warnings\": OptionInfo(True, \"Show gradio deprecation warnings in console.\").needs_reload_ui(),\r\n    \"memmon_poll_rate\": OptionInfo(8, \"VRAM usage polls per second during generation.\", gr.Slider, {\"minimum\": 0, \"maximum\": 40, \"step\": 1}).info(\"0 = disable\"),\r\n    
\"samples_log_stdout\": OptionInfo(False, \"Always print all generation info to standard output\"),\r\n    \"multiple_tqdm\": OptionInfo(True, \"Add a second progress bar to the console that shows progress for an entire job.\"),\r\n    \"enable_upscale_progressbar\": OptionInfo(True, \"Show a progress bar in the console for tiled upscaling.\"),\r\n    \"print_hypernet_extra\": OptionInfo(False, \"Print extra hypernetwork information to console.\"),\r\n    \"list_hidden_files\": OptionInfo(True, \"Load models/files in hidden directories\").info(\"directory is hidden if its name starts with \\\".\\\"\"),\r\n    \"disable_mmap_load_safetensors\": OptionInfo(False, \"Disable memmapping for loading .safetensors files.\").info(\"fixes very slow loading speed in some cases\"),\r\n    \"hide_ldm_prints\": OptionInfo(True, \"Prevent Stability-AI's ldm/sgm modules from printing noise to console.\"),\r\n    \"dump_stacks_on_signal\": OptionInfo(False, \"Print stack traces before exiting the program with ctrl+c.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('profiler', \"Profiler\", \"system\"), {\r\n    \"profiling_explanation\": OptionHTML(\"\"\"\r\nThose settings allow you to enable torch profiler when generating pictures.\r\nProfiling allows you to see which code uses how much of computer's resources during generation.\r\nEach generation writes its own profile to one file, overwriting previous.\r\nThe file can be viewed in <a href=\"chrome:tracing\">Chrome</a>, or on a <a href=\"https://ui.perfetto.dev/\">Perfetto</a> web site.\r\nWarning: writing profile can take a lot of time, up to 30 seconds, and the file itelf can be around 500MB in size.\r\n\"\"\"),\r\n    \"profiling_enable\": OptionInfo(False, \"Enable profiling\"),\r\n    \"profiling_activities\": OptionInfo([\"CPU\"], \"Activities\", gr.CheckboxGroup, {\"choices\": [\"CPU\", \"CUDA\"]}),\r\n    \"profiling_record_shapes\": OptionInfo(True, \"Record shapes\"),\r\n    \"profiling_profile_memory\": 
OptionInfo(True, \"Profile memory\"),\r\n    \"profiling_with_stack\": OptionInfo(True, \"Include python stack\"),\r\n    \"profiling_filename\": OptionInfo(\"trace.json\", \"Profile filename\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('API', \"API\", \"system\"), {\r\n    \"api_enable_requests\": OptionInfo(True, \"Allow http:// and https:// URLs for input images in API\", restrict_api=True),\r\n    \"api_forbid_local_requests\": OptionInfo(True, \"Forbid URLs to local resources\", restrict_api=True),\r\n    \"api_useragent\": OptionInfo(\"\", \"User agent for requests\", restrict_api=True),\r\n}))\r\n\r\noptions_templates.update(options_section(('training', \"Training\", \"training\"), {\r\n    \"unload_models_when_training\": OptionInfo(False, \"Move VAE and CLIP to RAM when training if possible. Saves VRAM.\"),\r\n    \"pin_memory\": OptionInfo(False, \"Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage.\"),\r\n    \"save_optimizer_state\": OptionInfo(False, \"Saves Optimizer state as separate *.optim file. 
Training of embedding or HN can be resumed with the matching optim file.\"),\r\n    \"save_training_settings_to_txt\": OptionInfo(True, \"Save textual inversion and hypernet settings to a text file whenever training starts.\"),\r\n    \"dataset_filename_word_regex\": OptionInfo(\"\", \"Filename word regex\"),\r\n    \"dataset_filename_join_string\": OptionInfo(\" \", \"Filename join string\"),\r\n    \"training_image_repeats_per_epoch\": OptionInfo(1, \"Number of repeats for a single input image per epoch; used only for displaying epoch number\", gr.Number, {\"precision\": 0}),\r\n    \"training_write_csv_every\": OptionInfo(500, \"Save an csv containing the loss to log directory every N steps, 0 to disable\"),\r\n    \"training_xattention_optimizations\": OptionInfo(False, \"Use cross attention optimizations while training\"),\r\n    \"training_enable_tensorboard\": OptionInfo(False, \"Enable tensorboard logging.\"),\r\n    \"training_tensorboard_save_images\": OptionInfo(False, \"Save generated images within tensorboard.\"),\r\n    \"training_tensorboard_flush_every\": OptionInfo(120, \"How often, in seconds, to flush the pending tensorboard events and summaries to disk.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('sd', \"Stable Diffusion\", \"sd\"), {\r\n    \"sd_model_checkpoint\": OptionInfo(None, \"Stable Diffusion checkpoint\", gr.Dropdown, lambda: {\"choices\": shared_items.list_checkpoint_tiles(shared.opts.sd_checkpoint_dropdown_use_short)}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'),\r\n    \"sd_checkpoints_limit\": OptionInfo(1, \"Maximum number of checkpoints loaded at the same time\", gr.Slider, {\"minimum\": 1, \"maximum\": 10, \"step\": 1}),\r\n    \"sd_checkpoints_keep_in_cpu\": OptionInfo(True, \"Only keep one model on device\").info(\"will keep models other than the currently used one in RAM rather than VRAM\"),\r\n    \"sd_checkpoint_cache\": OptionInfo(0, \"Checkpoints to cache in RAM\", gr.Slider, 
{\"minimum\": 0, \"maximum\": 10, \"step\": 1}).info(\"obsolete; set to 0 and use the two settings above instead\"),\r\n    \"sd_unet\": OptionInfo(\"Automatic\", \"SD Unet\", gr.Dropdown, lambda: {\"choices\": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info(\"choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint\"),\r\n    \"enable_quantization\": OptionInfo(False, \"Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds\").needs_reload_ui(),\r\n    \"emphasis\": OptionInfo(\"Original\", \"Emphasis mode\", gr.Radio, lambda: {\"choices\": [x.name for x in sd_emphasis.options]}, infotext=\"Emphasis\").info(\"makes it possible to make model to pay (more:1.1) or (less:0.9) attention to text when you use the syntax in prompt; \" + sd_emphasis.get_options_descriptions()),\r\n    \"enable_batch_seeds\": OptionInfo(True, \"Make K-diffusion samplers produce same images in a batch as when making a single image\"),\r\n    \"comma_padding_backtrack\": OptionInfo(20, \"Prompt word wrap length limit\", gr.Slider, {\"minimum\": 0, \"maximum\": 74, \"step\": 1}).info(\"in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk\"),\r\n    \"sdxl_clip_l_skip\": OptionInfo(False, \"Clip skip SDXL\", gr.Checkbox).info(\"Enable Clip skip for the secondary clip model in sdxl. 
Has no effect on SD 1.5 or SD 2.0/2.1.\"),\r\n    \"CLIP_stop_at_last_layers\": OptionInfo(1, \"Clip skip\", gr.Slider, {\"minimum\": 1, \"maximum\": 12, \"step\": 1}, infotext=\"Clip skip\").link(\"wiki\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip\").info(\"ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer\"),\r\n    \"upcast_attn\": OptionInfo(False, \"Upcast cross attention layer to float32\"),\r\n    \"randn_source\": OptionInfo(\"GPU\", \"Random number generator source.\", gr.Radio, {\"choices\": [\"GPU\", \"CPU\", \"NV\"]}, infotext=\"RNG\").info(\"changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards\"),\r\n    \"tiling\": OptionInfo(False, \"Tiling\", infotext='Tiling').info(\"produce a tileable picture\"),\r\n    \"hires_fix_refiner_pass\": OptionInfo(\"second pass\", \"Hires fix: which pass to enable refiner for\", gr.Radio, {\"choices\": [\"first pass\", \"second pass\", \"both passes\"]}, infotext=\"Hires refiner\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('sdxl', \"Stable Diffusion XL\", \"sd\"), {\r\n    \"sdxl_crop_top\": OptionInfo(0, \"crop top coordinate\"),\r\n    \"sdxl_crop_left\": OptionInfo(0, \"crop left coordinate\"),\r\n    \"sdxl_refiner_low_aesthetic_score\": OptionInfo(2.5, \"SDXL low aesthetic score\", gr.Number).info(\"used for refiner model negative prompt\"),\r\n    \"sdxl_refiner_high_aesthetic_score\": OptionInfo(6.0, \"SDXL high aesthetic score\", gr.Number).info(\"used for refiner model prompt\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('sd3', \"Stable Diffusion 3\", \"sd\"), {\r\n    \"sd3_enable_t5\": OptionInfo(False, \"Enable T5\").info(\"load T5 text encoder; increases VRAM use by a lot, potentially improving quality of generation; requires model reload to apply\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('vae', 
\"VAE\", \"sd\"), {\r\n    \"sd_vae_explanation\": OptionHTML(\"\"\"\r\n<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr>\r\nimage into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling\r\n(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished.\r\nFor img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling.\r\n\"\"\"),\r\n    \"sd_vae_checkpoint_cache\": OptionInfo(0, \"VAE Checkpoints to cache in RAM\", gr.Slider, {\"minimum\": 0, \"maximum\": 10, \"step\": 1}),\r\n    \"sd_vae\": OptionInfo(\"Automatic\", \"SD VAE\", gr.Dropdown, lambda: {\"choices\": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info(\"choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint\"),\r\n    \"sd_vae_overrides_per_model_preferences\": OptionInfo(True, \"Selected VAE overrides per-model preferences\").info(\"you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint\"),\r\n    \"auto_vae_precision_bfloat16\": OptionInfo(False, \"Automatically convert VAE to bfloat16\").info(\"triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image; if enabled, overrides the option below\"),\r\n    \"auto_vae_precision\": OptionInfo(True, \"Automatically revert VAE to 32-bit floats\").info(\"triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image\"),\r\n    \"sd_vae_encode_method\": OptionInfo(\"Full\", \"VAE type for encode\", gr.Radio, {\"choices\": [\"Full\", \"TAESD\"]}, infotext='VAE Encoder').info(\"method to encode image to 
latent (use in img2img, hires-fix or inpaint mask)\"),\r\n    \"sd_vae_decode_method\": OptionInfo(\"Full\", \"VAE type for decode\", gr.Radio, {\"choices\": [\"Full\", \"TAESD\"]}, infotext='VAE Decoder').info(\"method to decode latent to image\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('img2img', \"img2img\", \"sd\"), {\r\n    \"inpainting_mask_weight\": OptionInfo(1.0, \"Inpainting conditioning mask strength\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.0, \"step\": 0.01}, infotext='Conditional mask weight'),\r\n    \"initial_noise_multiplier\": OptionInfo(1.0, \"Noise multiplier for img2img\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.5, \"step\": 0.001}, infotext='Noise multiplier'),\r\n    \"img2img_extra_noise\": OptionInfo(0.0, \"Extra noise multiplier for img2img and hires fix\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.0, \"step\": 0.01}, infotext='Extra noise').info(\"0 = disabled (default); should be lower than denoising strength\"),\r\n    \"img2img_color_correction\": OptionInfo(False, \"Apply color correction to img2img results to match original colors.\"),\r\n    \"img2img_fix_steps\": OptionInfo(False, \"With img2img, do exactly the amount of steps the slider specifies.\").info(\"normally you'd do less with less denoising\"),\r\n    \"img2img_background_color\": OptionInfo(\"#ffffff\", \"With img2img, fill transparent parts of the input image with this color.\", ui_components.FormColorPicker, {}),\r\n    \"img2img_editor_height\": OptionInfo(720, \"Height of the image editor\", gr.Slider, {\"minimum\": 80, \"maximum\": 1600, \"step\": 1}).info(\"in pixels\").needs_reload_ui(),\r\n    \"img2img_sketch_default_brush_color\": OptionInfo(\"#ffffff\", \"Sketch initial brush color\", ui_components.FormColorPicker, {}).info(\"default brush color of img2img sketch\").needs_reload_ui(),\r\n    \"img2img_inpaint_mask_brush_color\": OptionInfo(\"#ffffff\", \"Inpaint mask brush color\", ui_components.FormColorPicker,  {}).info(\"brush 
color of inpaint mask\").needs_reload_ui(),\r\n    \"img2img_inpaint_sketch_default_brush_color\": OptionInfo(\"#ffffff\", \"Inpaint sketch initial brush color\", ui_components.FormColorPicker, {}).info(\"default brush color of img2img inpaint sketch\").needs_reload_ui(),\r\n    \"return_mask\": OptionInfo(False, \"For inpainting, include the greyscale mask in results for web\"),\r\n    \"return_mask_composite\": OptionInfo(False, \"For inpainting, include masked composite in results for web\"),\r\n    \"img2img_batch_show_results_limit\": OptionInfo(32, \"Show the first N batch img2img results in UI\", gr.Slider, {\"minimum\": -1, \"maximum\": 1000, \"step\": 1}).info('0: disable, -1: show all images. Too many images can cause lag'),\r\n    \"overlay_inpaint\": OptionInfo(True, \"Overlay original for inpaint\").info(\"when inpainting, overlay the original image over the areas that weren't inpainted.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('optimizations', \"Optimizations\", \"sd\"), {\r\n    \"cross_attention_optimization\": OptionInfo(\"Automatic\", \"Cross attention optimization\", gr.Dropdown, lambda: {\"choices\": shared_items.cross_attention_optimizations()}),\r\n    \"s_min_uncond\": OptionInfo(0.0, \"Negative Guidance minimum sigma\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 15.0, \"step\": 0.01}, infotext='NGMS').link(\"PR\", \"https://github.com/AUTOMATIC1111/stablediffusion-webui/pull/9177\").info(\"skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster\"),\r\n    \"s_min_uncond_all\": OptionInfo(False, \"Negative Guidance minimum sigma all steps\", infotext='NGMS all steps').info(\"By default, NGMS above skips every other step; this makes it skip all steps\"),\r\n    \"token_merging_ratio\": OptionInfo(0.0, \"Token merging ratio\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 0.9, \"step\": 0.1}, infotext='Token merging ratio').link(\"PR\", 
\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256\").info(\"0=disable, higher=faster\"),\r\n    \"token_merging_ratio_img2img\": OptionInfo(0.0, \"Token merging ratio for img2img\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 0.9, \"step\": 0.1}).info(\"only applies if non-zero and overrides above\"),\r\n    \"token_merging_ratio_hr\": OptionInfo(0.0, \"Token merging ratio for high-res pass\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 0.9, \"step\": 0.1}, infotext='Token merging ratio hr').info(\"only applies if non-zero and overrides above\"),\r\n    \"pad_cond_uncond\": OptionInfo(False, \"Pad prompt/negative prompt\", infotext='Pad conds').info(\"improves performance when prompt and negative prompt have different lengths; changes seeds\"),\r\n    \"pad_cond_uncond_v0\": OptionInfo(False, \"Pad prompt/negative prompt (v0)\", infotext='Pad conds v0').info(\"alternative implementation for the above; used prior to 1.6.0 for DDIM sampler; overrides the above if set; WARNING: truncates negative prompt if it's too long; changes seeds\"),\r\n    \"persistent_cond_cache\": OptionInfo(True, \"Persistent cond cache\").info(\"do not recalculate conds from prompts if prompts have not changed since previous calculation\"),\r\n    \"batch_cond_uncond\": OptionInfo(True, \"Batch cond/uncond\").info(\"do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond commandline argument\"),\r\n    \"fp8_storage\": OptionInfo(\"Disable\", \"FP8 weight\", gr.Radio, {\"choices\": [\"Disable\", \"Enable for SDXL\", \"Enable\"]}).info(\"Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0.\"),\r\n    \"cache_fp16_weight\": OptionInfo(False, \"Cache FP16 weight for LoRA\").info(\"Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('compatibility', \"Compatibility\", \"sd\"), {\r\n    \"auto_backcompat\": OptionInfo(True, \"Automatic backward compatibility\").info(\"automatically enable options for backwards compatibility when importing generation parameters from infotext that has program version.\"),\r\n    \"use_old_emphasis_implementation\": OptionInfo(False, \"Use old emphasis implementation. Can be useful to reproduce old seeds.\"),\r\n    \"use_old_karras_scheduler_sigmas\": OptionInfo(False, \"Use old karras scheduler sigmas (0.1 to 10).\"),\r\n    \"no_dpmpp_sde_batch_determinism\": OptionInfo(False, \"Do not make DPM++ SDE deterministic across different batch sizes.\"),\r\n    \"use_old_hires_fix_width_height\": OptionInfo(False, \"For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to).\"),\r\n    \"hires_fix_use_firstpass_conds\": OptionInfo(False, \"For hires fix, calculate conds of second pass using extra networks of first pass.\"),\r\n    \"use_old_scheduling\": OptionInfo(False, \"Use old prompt editing timelines.\", infotext=\"Old prompt editing timelines\").info(\"For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps\"),\r\n    \"use_downcasted_alpha_bar\": OptionInfo(False, \"Downcast model alphas_cumprod to fp16 before sampling. For reproducing old seeds.\", infotext=\"Downcast alphas_cumprod\"),\r\n    \"refiner_switch_by_sample_steps\": OptionInfo(False, \"Switch to refiner by sampling steps instead of model timesteps. 
Old behavior for refiner.\", infotext=\"Refiner switch by sampling steps\")\r\n}))\r\n\r\noptions_templates.update(options_section(('interrogate', \"Interrogate\"), {\r\n    \"interrogate_keep_models_in_memory\": OptionInfo(False, \"Keep models in VRAM\"),\r\n    \"interrogate_return_ranks\": OptionInfo(False, \"Include ranks of model tags matches in results.\").info(\"booru only\"),\r\n    \"interrogate_clip_num_beams\": OptionInfo(1, \"BLIP: num_beams\", gr.Slider, {\"minimum\": 1, \"maximum\": 16, \"step\": 1}),\r\n    \"interrogate_clip_min_length\": OptionInfo(24, \"BLIP: minimum description length\", gr.Slider, {\"minimum\": 1, \"maximum\": 128, \"step\": 1}),\r\n    \"interrogate_clip_max_length\": OptionInfo(48, \"BLIP: maximum description length\", gr.Slider, {\"minimum\": 1, \"maximum\": 256, \"step\": 1}),\r\n    \"interrogate_clip_dict_limit\": OptionInfo(1500, \"CLIP: maximum number of lines in text file\").info(\"0 = No limit\"),\r\n    \"interrogate_clip_skip_categories\": OptionInfo([], \"CLIP: skip inquire categories\", gr.CheckboxGroup, lambda: {\"choices\": interrogate.category_types()}, refresh=interrogate.category_types),\r\n    \"interrogate_deepbooru_score_threshold\": OptionInfo(0.5, \"deepbooru: score threshold\", gr.Slider, {\"minimum\": 0, \"maximum\": 1, \"step\": 0.01}),\r\n    \"deepbooru_sort_alpha\": OptionInfo(True, \"deepbooru: sort tags alphabetically\").info(\"if not: sort by score\"),\r\n    \"deepbooru_use_spaces\": OptionInfo(True, \"deepbooru: use spaces in tags\").info(\"if not: use underscores\"),\r\n    \"deepbooru_escape\": OptionInfo(True, \"deepbooru: escape (\\\\) brackets\").info(\"so they are used as literal brackets and not for emphasis\"),\r\n    \"deepbooru_filter_tags\": OptionInfo(\"\", \"deepbooru: filter out those tags\").info(\"separate by comma\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('extra_networks', \"Extra Networks\", \"sd\"), {\r\n    \"extra_networks_show_hidden_directories\": 
OptionInfo(True, \"Show hidden directories\").info(\"directory is hidden if its name starts with \\\".\\\".\"),\r\n    \"extra_networks_dir_button_function\": OptionInfo(False, \"Add a '/' to the beginning of directory buttons\").info(\"Buttons will display the contents of the selected directory without acting as a search filter.\"),\r\n    \"extra_networks_hidden_models\": OptionInfo(\"When searched\", \"Show cards for models in hidden directories\", gr.Radio, {\"choices\": [\"Always\", \"When searched\", \"Never\"]}).info('\"When searched\" option will only show the item when the search string has 4 characters or more'),\r\n    \"extra_networks_default_multiplier\": OptionInfo(1.0, \"Default multiplier for extra networks\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 2.0, \"step\": 0.01}),\r\n    \"extra_networks_card_width\": OptionInfo(0, \"Card width for Extra Networks\").info(\"in pixels\"),\r\n    \"extra_networks_card_height\": OptionInfo(0, \"Card height for Extra Networks\").info(\"in pixels\"),\r\n    \"extra_networks_card_text_scale\": OptionInfo(1.0, \"Card text scale\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 2.0, \"step\": 0.01}).info(\"1 = original size\"),\r\n    \"extra_networks_card_show_desc\": OptionInfo(True, \"Show description on card\"),\r\n    \"extra_networks_card_description_is_html\": OptionInfo(False, \"Treat card description as HTML\"),\r\n    \"extra_networks_card_order_field\": OptionInfo(\"Path\", \"Default order field for Extra Networks cards\", gr.Dropdown, {\"choices\": ['Path', 'Name', 'Date Created', 'Date Modified']}).needs_reload_ui(),\r\n    \"extra_networks_card_order\": OptionInfo(\"Ascending\", \"Default order for Extra Networks cards\", gr.Dropdown, {\"choices\": ['Ascending', 'Descending']}).needs_reload_ui(),\r\n    \"extra_networks_tree_view_style\": OptionInfo(\"Dirs\", \"Extra Networks directory view style\", gr.Radio, {\"choices\": [\"Tree\", \"Dirs\"]}).needs_reload_ui(),\r\n    
\"extra_networks_tree_view_default_enabled\": OptionInfo(True, \"Show the Extra Networks directory view by default\").needs_reload_ui(),\r\n    \"extra_networks_tree_view_default_width\": OptionInfo(180, \"Default width for the Extra Networks directory tree view\", gr.Number).needs_reload_ui(),\r\n    \"extra_networks_add_text_separator\": OptionInfo(\" \", \"Extra networks separator\").info(\"extra text to add before <...> when adding extra network to prompt\"),\r\n    \"ui_extra_networks_tab_reorder\": OptionInfo(\"\", \"Extra networks tab order\").needs_reload_ui(),\r\n    \"textual_inversion_print_at_load\": OptionInfo(False, \"Print a list of Textual Inversion embeddings when loading model\"),\r\n    \"textual_inversion_add_hashes_to_infotext\": OptionInfo(True, \"Add Textual Inversion hashes to infotext\"),\r\n    \"sd_hypernetwork\": OptionInfo(\"None\", \"Add hypernetwork to prompt\", gr.Dropdown, lambda: {\"choices\": [\"None\", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),\r\n}))\r\n\r\noptions_templates.update(options_section(('ui_prompt_editing', \"Prompt editing\", \"ui\"), {\r\n    \"keyedit_precision_attention\": OptionInfo(0.1, \"Precision for (attention:1.1) when editing the prompt with Ctrl+up/down\", gr.Slider, {\"minimum\": 0.01, \"maximum\": 0.2, \"step\": 0.001}),\r\n    \"keyedit_precision_extra\": OptionInfo(0.05, \"Precision for <extra networks:0.9> when editing the prompt with Ctrl+up/down\", gr.Slider, {\"minimum\": 0.01, \"maximum\": 0.2, \"step\": 0.001}),\r\n    \"keyedit_delimiters\": OptionInfo(r\".,\\/!?%^*;:{}=`~() \", \"Word delimiters when editing the prompt with Ctrl+up/down\"),\r\n    \"keyedit_delimiters_whitespace\": OptionInfo([\"Tab\", \"Carriage Return\", \"Line Feed\"], \"Ctrl+up/down whitespace delimiters\", gr.CheckboxGroup, lambda: {\"choices\": [\"Tab\", \"Carriage Return\", \"Line Feed\"]}),\r\n    \"keyedit_move\": OptionInfo(True, \"Alt+left/right moves prompt elements\"),\r\n    
\"disable_token_counters\": OptionInfo(False, \"Disable prompt token counters\"),\r\n    \"include_styles_into_token_counters\": OptionInfo(True, \"Count tokens of enabled styles\").info(\"When calculating how many tokens the prompt has, also consider tokens added by enabled styles.\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('ui_gallery', \"Gallery\", \"ui\"), {\r\n    \"return_grid\": OptionInfo(True, \"Show grid in gallery\"),\r\n    \"do_not_show_images\": OptionInfo(False, \"Do not show any images in gallery\"),\r\n    \"js_modal_lightbox\": OptionInfo(True, \"Full page image viewer: enable\"),\r\n    \"js_modal_lightbox_initially_zoomed\": OptionInfo(True, \"Full page image viewer: show images zoomed in by default\"),\r\n    \"js_modal_lightbox_gamepad\": OptionInfo(False, \"Full page image viewer: navigate with gamepad\"),\r\n    \"js_modal_lightbox_gamepad_repeat\": OptionInfo(250, \"Full page image viewer: gamepad repeat period\").info(\"in milliseconds\"),\r\n    \"sd_webui_modal_lightbox_icon_opacity\": OptionInfo(1, \"Full page image viewer: control icon unfocused opacity\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1, \"step\": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),\r\n    \"sd_webui_modal_lightbox_toolbar_opacity\": OptionInfo(0.9, \"Full page image viewer: tool bar opacity\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1, \"step\": 0.01}, onchange=shared.reload_gradio_theme).info('for mouse only').needs_reload_ui(),\r\n    \"gallery_height\": OptionInfo(\"\", \"Gallery height\", gr.Textbox).info(\"can be any valid CSS value, for example 768px or 20em\").needs_reload_ui(),\r\n    \"open_dir_button_choice\": OptionInfo(\"Subdirectory\", \"What directory the [📂] button opens\", gr.Radio, {\"choices\": [\"Output Root\", \"Subdirectory\", \"Subdirectory (even temp dir)\"]}),\r\n}))\r\n\r\noptions_templates.update(options_section(('ui_alternatives', \"UI alternatives\", \"ui\"), {\r\n    
\"compact_prompt_box\": OptionInfo(False, \"Compact prompt layout\").info(\"puts prompt and negative prompt inside the Generate tab, leaving more vertical space for the image on the right\").needs_reload_ui(),\r\n    \"samplers_in_dropdown\": OptionInfo(True, \"Use dropdown for sampler selection instead of radio group\").needs_reload_ui(),\r\n    \"dimensions_and_batch_together\": OptionInfo(True, \"Show Width/Height and Batch sliders in same row\").needs_reload_ui(),\r\n    \"sd_checkpoint_dropdown_use_short\": OptionInfo(False, \"Checkpoint dropdown: use filenames without paths\").info(\"models in subdirectories like photo/sd15.ckpt will be listed as just sd15.ckpt\"),\r\n    \"hires_fix_show_sampler\": OptionInfo(False, \"Hires fix: show hires checkpoint and sampler selection\").needs_reload_ui(),\r\n    \"hires_fix_show_prompts\": OptionInfo(False, \"Hires fix: show hires prompt and negative prompt\").needs_reload_ui(),\r\n    \"txt2img_settings_accordion\": OptionInfo(False, \"Settings in txt2img hidden under Accordion\").needs_reload_ui(),\r\n    \"img2img_settings_accordion\": OptionInfo(False, \"Settings in img2img hidden under Accordion\").needs_reload_ui(),\r\n    \"interrupt_after_current\": OptionInfo(True, \"Don't Interrupt in the middle\").info(\"when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('ui', \"User interface\", \"ui\"), {\r\n    \"localization\": OptionInfo(\"None\", \"Localization\", gr.Dropdown, lambda: {\"choices\": [\"None\"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(),\r\n    \"quicksettings_list\": OptionInfo([\"sd_model_checkpoint\"], \"Quicksettings list\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared.opts.data_labels.keys())}).js(\"info\", 
\"settingsHintsShowQuicksettings\").info(\"setting entries that appear at the top of page rather than in settings tab\").needs_reload_ui(),\r\n    \"ui_tab_order\": OptionInfo([], \"UI tab order\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared.tab_names)}).needs_reload_ui(),\r\n    \"hidden_tabs\": OptionInfo([], \"Hidden UI tabs\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared.tab_names)}).needs_reload_ui(),\r\n    \"ui_reorder_list\": OptionInfo([], \"UI item order for txt2img/img2img tabs\", ui_components.DropdownMulti, lambda: {\"choices\": list(shared_items.ui_reorder_categories())}).info(\"selected items appear first\").needs_reload_ui(),\r\n    \"gradio_theme\": OptionInfo(\"Default\", \"Gradio theme\", ui_components.DropdownEditable, lambda: {\"choices\": [\"Default\"] + shared_gradio_themes.gradio_hf_hub_themes}).info(\"you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.\").needs_reload_ui(),\r\n    \"gradio_themes_cache\": OptionInfo(True, \"Cache gradio themes locally\").info(\"disable to update the selected Gradio theme\"),\r\n    \"show_progress_in_title\": OptionInfo(True, \"Show generation progress in window title.\"),\r\n    \"send_seed\": OptionInfo(True, \"Send seed when sending prompt or image to other interface\"),\r\n    \"send_size\": OptionInfo(True, \"Send size when sending prompt or image to another interface\"),\r\n    \"enable_reloading_ui_scripts\": OptionInfo(False, \"Reload UI scripts when using Reload UI option\").info(\"useful for developing: if you make changes to UI scripts code, it is applied when the UI is reloded.\"),\r\n\r\n}))\r\n\r\n\r\noptions_templates.update(options_section(('infotext', \"Infotext\", \"ui\"), {\r\n    \"infotext_explanation\": OptionHTML(\"\"\"\r\nInfotext is what this software calls the text that contains generation parameters and can be used to generate the same picture again.\r\nIt is displayed in 
UI below the image. To use infotext, paste it into the prompt and click the ↙️ paste button.\r\n\"\"\"),\r\n    \"enable_pnginfo\": OptionInfo(True, \"Write infotext to metadata of the generated image\"),\r\n    \"save_txt\": OptionInfo(False, \"Create a text file with infotext next to every generated image\"),\r\n\r\n    \"add_model_name_to_info\": OptionInfo(True, \"Add model name to infotext\"),\r\n    \"add_model_hash_to_info\": OptionInfo(True, \"Add model hash to infotext\"),\r\n    \"add_vae_name_to_info\": OptionInfo(True, \"Add VAE name to infotext\"),\r\n    \"add_vae_hash_to_info\": OptionInfo(True, \"Add VAE hash to infotext\"),\r\n    \"add_user_name_to_info\": OptionInfo(False, \"Add user name to infotext when authenticated\"),\r\n    \"add_version_to_infotext\": OptionInfo(True, \"Add program version to infotext\"),\r\n    \"disable_weights_auto_swap\": OptionInfo(True, \"Disregard checkpoint information from pasted infotext\").info(\"when reading generation parameters from text into UI\"),\r\n    \"infotext_skip_pasting\": OptionInfo([], \"Disregard fields from pasted infotext\", ui_components.DropdownMulti, lambda: {\"choices\": shared_items.get_infotext_names()}),\r\n    \"infotext_styles\": OptionInfo(\"Apply if any\", \"Infer styles from prompts of pasted infotext\", gr.Radio, {\"choices\": [\"Ignore\", \"Apply\", \"Discard\", \"Apply if any\"]}).info(\"when reading generation parameters from text into UI)\").html(\"\"\"<ul style='margin-left: 1.5em'>\r\n<li>Ignore: keep prompt and styles dropdown as it is.</li>\r\n<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li>\r\n<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li>\r\n<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it 
is.</li>\r\n</ul>\"\"\"),\r\n\r\n}))\r\n\r\noptions_templates.update(options_section(('ui', \"Live previews\", \"ui\"), {\r\n    \"show_progressbar\": OptionInfo(True, \"Show progressbar\"),\r\n    \"live_previews_enable\": OptionInfo(True, \"Show live previews of the created image\"),\r\n    \"live_previews_image_format\": OptionInfo(\"png\", \"Live preview file format\", gr.Radio, {\"choices\": [\"jpeg\", \"png\", \"webp\"]}),\r\n    \"show_progress_grid\": OptionInfo(True, \"Show previews of all images generated in a batch as a grid\"),\r\n    \"show_progress_every_n_steps\": OptionInfo(10, \"Live preview display period\", gr.Slider, {\"minimum\": -1, \"maximum\": 32, \"step\": 1}).info(\"in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch\"),\r\n    \"show_progress_type\": OptionInfo(\"Approx NN\", \"Live preview method\", gr.Radio, {\"choices\": [\"Full\", \"Approx NN\", \"Approx cheap\", \"TAESD\"]}).info(\"Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise\"),\r\n    \"live_preview_allow_lowvram_full\": OptionInfo(False, \"Allow Full live preview method with lowvram/medvram\").info(\"If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled\"),\r\n    \"live_preview_content\": OptionInfo(\"Prompt\", \"Live preview subject\", gr.Radio, {\"choices\": [\"Combined\", \"Prompt\", \"Negative prompt\"]}),\r\n    \"live_preview_refresh_period\": OptionInfo(1000, \"Progressbar and preview update period\").info(\"in milliseconds\"),\r\n    \"live_preview_fast_interrupt\": OptionInfo(False, \"Return image with chosen live preview method on interrupt\").info(\"makes interrupts faster\"),\r\n    \"js_live_preview_in_modal_lightbox\": OptionInfo(False, \"Show Live preview in full page image viewer\"),\r\n    \"prevent_screen_sleep_during_generation\": OptionInfo(True, 
\"Prevent screen sleep during generation\"),\r\n}))\r\n\r\noptions_templates.update(options_section(('sampler-params', \"Sampler parameters\", \"sd\"), {\r\n    \"hide_samplers\": OptionInfo([], \"Hide samplers in user interface\", gr.CheckboxGroup, lambda: {\"choices\": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(),\r\n    \"eta_ddim\": OptionInfo(0.0, \"Eta for DDIM\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.0, \"step\": 0.01}, infotext='Eta DDIM').info(\"noise multiplier; higher = more unpredictable results\"),\r\n    \"eta_ancestral\": OptionInfo(1.0, \"Eta for k-diffusion samplers\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.0, \"step\": 0.01}, infotext='Eta').info(\"noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers\"),\r\n    \"ddim_discretize\": OptionInfo('uniform', \"img2img DDIM discretize\", gr.Radio, {\"choices\": ['uniform', 'quad']}),\r\n    's_churn': OptionInfo(0.0, \"sigma churn\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 100.0, \"step\": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'),\r\n    's_tmin':  OptionInfo(0.0, \"sigma tmin\",  gr.Slider, {\"minimum\": 0.0, \"maximum\": 10.0, \"step\": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'),\r\n    's_tmax':  OptionInfo(0.0, \"sigma tmax\",  gr.Slider, {\"minimum\": 0.0, \"maximum\": 999.0, \"step\": 0.01}, infotext='Sigma tmax').info(\"0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2\"),\r\n    's_noise': OptionInfo(1.0, \"sigma noise\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.1, \"step\": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'),\r\n    'sigma_min': OptionInfo(0.0, \"sigma min\", gr.Number, infotext='Schedule min sigma').info(\"0 = default (~0.03); minimum noise strength for 
k-diffusion noise scheduler\"),\r\n    'sigma_max': OptionInfo(0.0, \"sigma max\", gr.Number, infotext='Schedule max sigma').info(\"0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler\"),\r\n    'rho':  OptionInfo(0.0, \"rho\", gr.Number, infotext='Schedule rho').info(\"0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)\"),\r\n    'eta_noise_seed_delta': OptionInfo(0, \"Eta noise seed delta\", gr.Number, {\"precision\": 0}, infotext='ENSD').info(\"ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images\"),\r\n    'always_discard_next_to_last_sigma': OptionInfo(False, \"Always discard next-to-last sigma\", infotext='Discard penultimate sigma').link(\"PR\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044\"),\r\n    'sgm_noise_multiplier': OptionInfo(False, \"SGM noise multiplier\", infotext='SGM noise multiplier').link(\"PR\", \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12818\").info(\"Match initial noise to official SDXL implementation - only useful for reproducing images\"),\r\n    'uni_pc_variant': OptionInfo(\"bh1\", \"UniPC variant\", gr.Radio, {\"choices\": [\"bh1\", \"bh2\", \"vary_coeff\"]}, infotext='UniPC variant'),\r\n    'uni_pc_skip_type': OptionInfo(\"time_uniform\", \"UniPC skip type\", gr.Radio, {\"choices\": [\"time_uniform\", \"time_quadratic\", \"logSNR\"]}, infotext='UniPC skip type'),\r\n    'uni_pc_order': OptionInfo(3, \"UniPC order\", gr.Slider, {\"minimum\": 1, \"maximum\": 50, \"step\": 1}, infotext='UniPC order').info(\"must be < sampling steps\"),\r\n    'uni_pc_lower_order_final': OptionInfo(True, \"UniPC lower order final\", infotext='UniPC lower order final'),\r\n    'sd_noise_schedule': OptionInfo(\"Default\", \"Noise schedule for sampling\", gr.Radio, {\"choices\": [\"Default\", \"Zero Terminal SNR\"]}, infotext=\"Noise 
Schedule\").info(\"for use with zero terminal SNR trained models\"),\r\n    'skip_early_cond': OptionInfo(0.0, \"Ignore negative prompt during early sampling\", gr.Slider, {\"minimum\": 0.0, \"maximum\": 1.0, \"step\": 0.01}, infotext=\"Skip Early CFG\").info(\"disables CFG on a proportion of steps at the beginning of generation; 0=skip none; 1=skip all; can both improve sample diversity/quality and speed up sampling\"),\r\n    'beta_dist_alpha': OptionInfo(0.6, \"Beta scheduler - alpha\", gr.Slider, {\"minimum\": 0.01, \"maximum\": 1.0, \"step\": 0.01}, infotext='Beta scheduler alpha').info('Default = 0.6; the alpha parameter of the beta distribution used in Beta sampling'),\r\n    'beta_dist_beta': OptionInfo(0.6, \"Beta scheduler - beta\", gr.Slider, {\"minimum\": 0.01, \"maximum\": 1.0, \"step\": 0.01}, infotext='Beta scheduler beta').info('Default = 0.6; the beta parameter of the beta distribution used in Beta sampling'),\r\n}))\r\n\r\noptions_templates.update(options_section(('postprocessing', \"Postprocessing\", \"postprocessing\"), {\r\n    'postprocessing_enable_in_main_ui': OptionInfo([], \"Enable postprocessing operations in txt2img and img2img tabs\", ui_components.DropdownMulti, lambda: {\"choices\": [x.name for x in shared_items.postprocessing_scripts()]}),\r\n    'postprocessing_disable_in_extras': OptionInfo([], \"Disable postprocessing operations in extras tab\", ui_components.DropdownMulti, lambda: {\"choices\": [x.name for x in shared_items.postprocessing_scripts()]}),\r\n    'postprocessing_operation_order': OptionInfo([], \"Postprocessing operation order\", ui_components.DropdownMulti, lambda: {\"choices\": [x.name for x in shared_items.postprocessing_scripts()]}),\r\n    'upscaling_max_images_in_cache': OptionInfo(5, \"Maximum number of images in upscaling cache\", gr.Slider, {\"minimum\": 0, \"maximum\": 10, \"step\": 1}),\r\n    'postprocessing_existing_caption_action': OptionInfo(\"Ignore\", \"Action for existing captions\", gr.Radio, 
{\"choices\": [\"Ignore\", \"Keep\", \"Prepend\", \"Append\"]}).info(\"when generating captions using postprocessing; Ignore = use generated; Keep = use original; Prepend/Append = combine both\"),\r\n}))\r\n\r\noptions_templates.update(options_section((None, \"Hidden options\"), {\r\n    \"disabled_extensions\": OptionInfo([], \"Disable these extensions\"),\r\n    \"disable_all_extensions\": OptionInfo(\"none\", \"Disable all extensions (preserves the list of disabled extensions)\", gr.Radio, {\"choices\": [\"none\", \"extra\", \"all\"]}),\r\n    \"restore_config_state_file\": OptionInfo(\"\", \"Config state file to restore from, under 'config-states/' folder\"),\r\n    \"sd_checkpoint_hash\": OptionInfo(\"\", \"SHA256 hash of the current checkpoint\"),\r\n}))\r\n"
  },
  {
    "path": "modules/shared_state.py",
    "content": "import datetime\r\nimport logging\r\nimport threading\r\nimport time\r\n\r\nfrom modules import errors, shared, devices\r\nfrom typing import Optional\r\n\r\nlog = logging.getLogger(__name__)\r\n\r\n\r\nclass State:\r\n    skipped = False\r\n    interrupted = False\r\n    stopping_generation = False\r\n    job = \"\"\r\n    job_no = 0\r\n    job_count = 0\r\n    processing_has_refined_job_count = False\r\n    job_timestamp = '0'\r\n    sampling_step = 0\r\n    sampling_steps = 0\r\n    current_latent = None\r\n    current_image = None\r\n    current_image_sampling_step = 0\r\n    id_live_preview = 0\r\n    textinfo = None\r\n    time_start = None\r\n    server_start = None\r\n    _server_command_signal = threading.Event()\r\n    _server_command: Optional[str] = None\r\n\r\n    def __init__(self):\r\n        self.server_start = time.time()\r\n\r\n    @property\r\n    def need_restart(self) -> bool:\r\n        # Compatibility getter for need_restart.\r\n        return self.server_command == \"restart\"\r\n\r\n    @need_restart.setter\r\n    def need_restart(self, value: bool) -> None:\r\n        # Compatibility setter for need_restart.\r\n        if value:\r\n            self.server_command = \"restart\"\r\n\r\n    @property\r\n    def server_command(self):\r\n        return self._server_command\r\n\r\n    @server_command.setter\r\n    def server_command(self, value: Optional[str]) -> None:\r\n        \"\"\"\r\n        Set the server command to `value` and signal that it's been set.\r\n        \"\"\"\r\n        self._server_command = value\r\n        self._server_command_signal.set()\r\n\r\n    def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:\r\n        \"\"\"\r\n        Wait for server command to get set; return and clear the value and signal.\r\n        \"\"\"\r\n        if self._server_command_signal.wait(timeout):\r\n            self._server_command_signal.clear()\r\n            req = self._server_command\r\n    
        self._server_command = None\r\n            return req\r\n        return None\r\n\r\n    def request_restart(self) -> None:\r\n        self.interrupt()\r\n        self.server_command = \"restart\"\r\n        log.info(\"Received restart request\")\r\n\r\n    def skip(self):\r\n        self.skipped = True\r\n        log.info(\"Received skip request\")\r\n\r\n    def interrupt(self):\r\n        self.interrupted = True\r\n        log.info(\"Received interrupt request\")\r\n\r\n    def stop_generating(self):\r\n        self.stopping_generation = True\r\n        log.info(\"Received stop generating request\")\r\n\r\n    def nextjob(self):\r\n        if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:\r\n            self.do_set_current_image()\r\n\r\n        self.job_no += 1\r\n        self.sampling_step = 0\r\n        self.current_image_sampling_step = 0\r\n\r\n    def dict(self):\r\n        obj = {\r\n            \"skipped\": self.skipped,\r\n            \"interrupted\": self.interrupted,\r\n            \"stopping_generation\": self.stopping_generation,\r\n            \"job\": self.job,\r\n            \"job_count\": self.job_count,\r\n            \"job_timestamp\": self.job_timestamp,\r\n            \"job_no\": self.job_no,\r\n            \"sampling_step\": self.sampling_step,\r\n            \"sampling_steps\": self.sampling_steps,\r\n        }\r\n\r\n        return obj\r\n\r\n    def begin(self, job: str = \"(unknown)\"):\r\n        self.sampling_step = 0\r\n        self.time_start = time.time()\r\n        self.job_count = -1\r\n        self.processing_has_refined_job_count = False\r\n        self.job_no = 0\r\n        self.job_timestamp = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\")\r\n        self.current_latent = None\r\n        self.current_image = None\r\n        self.current_image_sampling_step = 0\r\n        self.id_live_preview = 0\r\n        self.skipped = False\r\n        self.interrupted = False\r\n        
self.stopping_generation = False\r\n        self.textinfo = None\r\n        self.job = job\r\n        devices.torch_gc()\r\n        log.info(\"Starting job %s\", job)\r\n\r\n    def end(self):\r\n        duration = time.time() - self.time_start\r\n        log.info(\"Ending job %s (%.2f seconds)\", self.job, duration)\r\n        self.job = \"\"\r\n        self.job_count = 0\r\n\r\n        devices.torch_gc()\r\n\r\n    def set_current_image(self):\r\n        \"\"\"if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly\"\"\"\r\n        if not shared.parallel_processing_allowed:\r\n            return\r\n\r\n        if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:\r\n            self.do_set_current_image()\r\n\r\n    def do_set_current_image(self):\r\n        if self.current_latent is None:\r\n            return\r\n\r\n        import modules.sd_samplers\r\n\r\n        try:\r\n            if shared.opts.show_progress_grid:\r\n                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))\r\n            else:\r\n                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))\r\n\r\n            self.current_image_sampling_step = self.sampling_step\r\n\r\n        except Exception:\r\n            # when switching models during generation, VAE would be on CPU, so creating an image will fail.\r\n            # we silently ignore this error\r\n            errors.record_exception()\r\n\r\n    def assign_current_image(self, image):\r\n        if shared.opts.live_previews_image_format == 'jpeg' and image.mode in ('RGBA', 'P'):\r\n            image = image.convert('RGB')\r\n        self.current_image = image\r\n        self.id_live_preview += 1\r\n"
  },
  {
    "path": "modules/shared_total_tqdm.py",
    "content": "import tqdm\r\n\r\nfrom modules import shared\r\n\r\n\r\nclass TotalTQDM:\r\n    def __init__(self):\r\n        self._tqdm = None\r\n\r\n    def reset(self):\r\n        self._tqdm = tqdm.tqdm(\r\n            desc=\"Total progress\",\r\n            total=shared.state.job_count * shared.state.sampling_steps,\r\n            position=1,\r\n            file=shared.progress_print_out\r\n        )\r\n\r\n    def update(self):\r\n        if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:\r\n            return\r\n        if self._tqdm is None:\r\n            self.reset()\r\n        self._tqdm.update()\r\n\r\n    def updateTotal(self, new_total):\r\n        if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:\r\n            return\r\n        if self._tqdm is None:\r\n            self.reset()\r\n        self._tqdm.total = new_total\r\n\r\n    def clear(self):\r\n        if self._tqdm is not None:\r\n            self._tqdm.refresh()\r\n            self._tqdm.close()\r\n            self._tqdm = None\r\n\r\n"
  },
  {
    "path": "modules/styles.py",
    "content": "from __future__ import annotations\r\nfrom pathlib import Path\r\nfrom modules import errors\r\nimport csv\r\nimport os\r\nimport typing\r\nimport shutil\r\n\r\n\r\nclass PromptStyle(typing.NamedTuple):\r\n    name: str\r\n    prompt: str | None\r\n    negative_prompt: str | None\r\n    path: str | None = None\r\n\r\n\r\ndef merge_prompts(style_prompt: str, prompt: str) -> str:\r\n    if \"{prompt}\" in style_prompt:\r\n        res = style_prompt.replace(\"{prompt}\", prompt)\r\n    else:\r\n        parts = filter(None, (prompt.strip(), style_prompt.strip()))\r\n        res = \", \".join(parts)\r\n\r\n    return res\r\n\r\n\r\ndef apply_styles_to_prompt(prompt, styles):\r\n    for style in styles:\r\n        prompt = merge_prompts(style, prompt)\r\n\r\n    return prompt\r\n\r\n\r\ndef extract_style_text_from_prompt(style_text, prompt):\r\n    \"\"\"This function extracts the text from a given prompt based on a provided style text. It checks if the style text contains the placeholder {prompt} or if it appears at the end of the prompt. If a match is found, it returns True along with the extracted text. 
Otherwise, it returns False and the original prompt.\r\n\r\n    extract_style_text_from_prompt(\"masterpiece\", \"1girl, art by greg, masterpiece\") outputs (True, \"1girl, art by greg\")\r\n    extract_style_text_from_prompt(\"masterpiece, {prompt}\", \"masterpiece, 1girl, art by greg\") outputs (True, \"1girl, art by greg\")\r\n    extract_style_text_from_prompt(\"masterpiece, {prompt}\", \"exquisite, 1girl, art by greg\") outputs (False, \"exquisite, 1girl, art by greg\")\r\n    \"\"\"\r\n\r\n    stripped_prompt = prompt.strip()\r\n    stripped_style_text = style_text.strip()\r\n\r\n    if \"{prompt}\" in stripped_style_text:\r\n        left, _, right = stripped_style_text.partition(\"{prompt}\")\r\n        if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):\r\n            prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]\r\n            return True, prompt\r\n    else:\r\n        if stripped_prompt.endswith(stripped_style_text):\r\n            prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]\r\n\r\n            if prompt.endswith(', '):\r\n                prompt = prompt[:-2]\r\n\r\n            return True, prompt\r\n\r\n    return False, prompt\r\n\r\n\r\ndef extract_original_prompts(style: PromptStyle, prompt, negative_prompt):\r\n    \"\"\"\r\n    Takes a style and compares it to the prompt and negative prompt. If the style\r\n    matches, returns True plus the prompt and negative prompt with the style text\r\n    removed. 
Otherwise, returns False with the original prompt and negative prompt.\r\n    \"\"\"\r\n    if not style.prompt and not style.negative_prompt:\r\n        return False, prompt, negative_prompt\r\n\r\n    match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)\r\n    if not match_positive:\r\n        return False, prompt, negative_prompt\r\n\r\n    match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)\r\n    if not match_negative:\r\n        return False, prompt, negative_prompt\r\n\r\n    return True, extracted_positive, extracted_negative\r\n\r\n\r\nclass StyleDatabase:\r\n    def __init__(self, paths: list[str | Path]):\r\n        self.no_style = PromptStyle(\"None\", \"\", \"\", None)\r\n        self.styles = {}\r\n        self.paths = paths\r\n        self.all_styles_files: list[Path] = []\r\n\r\n        folder, file = os.path.split(self.paths[0])\r\n        if '*' in file or '?' in file:\r\n            # if the first path is a wildcard pattern, find the first match else use \"folder/styles.csv\" as the default path\r\n            self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))\r\n            self.paths.insert(0, self.default_path)\r\n        else:\r\n            self.default_path = Path(self.paths[0])\r\n\r\n        self.prompt_fields = [field for field in PromptStyle._fields if field != \"path\"]\r\n\r\n        self.reload()\r\n\r\n    def reload(self):\r\n        \"\"\"\r\n        Clears the style database and reloads the styles from the CSV file(s)\r\n        matching the path used to initialize the database.\r\n        \"\"\"\r\n        self.styles.clear()\r\n\r\n        # scans for all styles files\r\n        all_styles_files = []\r\n        for pattern in self.paths:\r\n            folder, file = os.path.split(pattern)\r\n            if '*' in file or '?' 
in file:\r\n                found_files = Path(folder).glob(file)\r\n                [all_styles_files.append(file) for file in found_files]\r\n            else:\r\n                # if os.path.exists(pattern):\r\n                all_styles_files.append(Path(pattern))\r\n\r\n        # Remove any duplicate entries\r\n        seen = set()\r\n        self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]\r\n\r\n        for styles_file in self.all_styles_files:\r\n            if len(all_styles_files) > 1:\r\n                # add a divider when more than one styles file was found (note: this counts entries before de-duplication)\r\n                # '---------------- STYLES ----------------'\r\n                divider = f' {styles_file.stem.upper()} '.center(40, '-')\r\n                self.styles[divider] = PromptStyle(f\"{divider}\", None, None, \"do_not_save\")\r\n            if styles_file.is_file():\r\n                self.load_from_csv(styles_file)\r\n\r\n    def load_from_csv(self, path: str | Path):\r\n        try:\r\n            with open(path, \"r\", encoding=\"utf-8-sig\", newline=\"\") as file:\r\n                reader = csv.DictReader(file, skipinitialspace=True)\r\n                for row in reader:\r\n                    # Ignore empty rows or rows starting with a comment\r\n                    if not row or row[\"name\"].startswith(\"#\"):\r\n                        continue\r\n                    # Support loading old CSV format with \"name, text\"-columns\r\n                    prompt = row[\"prompt\"] if \"prompt\" in row else row[\"text\"]\r\n                    negative_prompt = row.get(\"negative_prompt\", \"\")\r\n                    # Add style to database\r\n                    self.styles[row[\"name\"]] = PromptStyle(\r\n                        row[\"name\"], prompt, negative_prompt, str(path)\r\n                    )\r\n        except Exception:\r\n            errors.report(f'Error loading styles from {path}: ', exc_info=True)\r\n\r\n    def get_style_paths(self) -> 
set:\r\n        \"\"\"Returns a set of all distinct paths of files that styles are loaded from.\"\"\"\r\n        # Update any styles without a path to the default path\r\n        for style in list(self.styles.values()):\r\n            if not style.path:\r\n                self.styles[style.name] = style._replace(path=str(self.default_path))\r\n\r\n        # Create a list of all distinct paths, including the default path\r\n        style_paths = set()\r\n        style_paths.add(str(self.default_path))\r\n        for _, style in self.styles.items():\r\n            if style.path:\r\n                style_paths.add(style.path)\r\n\r\n        # Remove any paths for styles that are just list dividers\r\n        style_paths.discard(\"do_not_save\")\r\n\r\n        return style_paths\r\n\r\n    def get_style_prompts(self, styles):\r\n        return [self.styles.get(x, self.no_style).prompt for x in styles]\r\n\r\n    def get_negative_style_prompts(self, styles):\r\n        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]\r\n\r\n    def apply_styles_to_prompt(self, prompt, styles):\r\n        return apply_styles_to_prompt(\r\n            prompt, [self.styles.get(x, self.no_style).prompt for x in styles]\r\n        )\r\n\r\n    def apply_negative_styles_to_prompt(self, prompt, styles):\r\n        return apply_styles_to_prompt(\r\n            prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]\r\n        )\r\n\r\n    def save_styles(self, path: str = None) -> None:\r\n        # The path argument is deprecated, but kept for backwards compatibility\r\n\r\n        style_paths = self.get_style_paths()\r\n\r\n        csv_names = [os.path.split(path)[1].lower() for path in style_paths]\r\n\r\n        for style_path in style_paths:\r\n            # Always keep a backup file around\r\n            if os.path.exists(style_path):\r\n                shutil.copy(style_path, f\"{style_path}.bak\")\r\n\r\n            # Write the styles to the 
CSV file\r\n            with open(style_path, \"w\", encoding=\"utf-8-sig\", newline=\"\") as file:\r\n                writer = csv.DictWriter(file, fieldnames=self.prompt_fields)\r\n                writer.writeheader()\r\n                for style in (s for s in self.styles.values() if s.path == style_path):\r\n                    # Skip style list dividers, e.g. \"STYLES.CSV\"\r\n                    if style.name.lower().strip(\"# \") in csv_names:\r\n                        continue\r\n                    # Write style fields, ignoring the path field\r\n                    writer.writerow(\r\n                        {k: v for k, v in style._asdict().items() if k != \"path\"}\r\n                    )\r\n\r\n    def extract_styles_from_prompt(self, prompt, negative_prompt):\r\n        extracted = []\r\n\r\n        applicable_styles = list(self.styles.values())\r\n\r\n        while True:\r\n            found_style = None\r\n\r\n            for style in applicable_styles:\r\n                is_match, new_prompt, new_neg_prompt = extract_original_prompts(\r\n                    style, prompt, negative_prompt\r\n                )\r\n                if is_match:\r\n                    found_style = style\r\n                    prompt = new_prompt\r\n                    negative_prompt = new_neg_prompt\r\n                    break\r\n\r\n            if not found_style:\r\n                break\r\n\r\n            applicable_styles.remove(found_style)\r\n            extracted.append(found_style.name)\r\n\r\n        return list(reversed(extracted)), prompt, negative_prompt\r\n"
  },
  {
    "path": "modules/sub_quadratic_attention.py",
    "content": "# original source:\n#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py\n# license:\n#   MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)\n# credit:\n#   Amin Rezaei (original author)\n#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)\n#   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)\n# implementation of:\n#   Self-attention Does Not Need O(n2) Memory\":\n#   https://arxiv.org/abs/2112.05682v2\n\nfrom functools import partial\nimport torch\nfrom torch import Tensor\nfrom torch.utils.checkpoint import checkpoint\nimport math\nfrom typing import Optional, NamedTuple\n\n\ndef narrow_trunc(\n    input: Tensor,\n    dim: int,\n    start: int,\n    length: int\n) -> Tensor:\n    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)\n\n\nclass AttnChunk(NamedTuple):\n    exp_values: Tensor\n    exp_weights_sum: Tensor\n    max_score: Tensor\n\n\nclass SummarizeChunk:\n    @staticmethod\n    def __call__(\n        query: Tensor,\n        key: Tensor,\n        value: Tensor,\n    ) -> AttnChunk: ...\n\n\nclass ComputeQueryChunkAttn:\n    @staticmethod\n    def __call__(\n        query: Tensor,\n        key: Tensor,\n        value: Tensor,\n    ) -> Tensor: ...\n\n\ndef _summarize_chunk(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    scale: float,\n) -> AttnChunk:\n    attn_weights = torch.baddbmm(\n        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),\n        query,\n        key.transpose(1,2),\n        alpha=scale,\n        beta=0,\n    )\n    max_score, _ = torch.max(attn_weights, -1, keepdim=True)\n    max_score = max_score.detach()\n    exp_weights = torch.exp(attn_weights - max_score)\n    exp_values 
= torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)\n    max_score = max_score.squeeze(-1)\n    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)\n\n\ndef _query_chunk_attention(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    summarize_chunk: SummarizeChunk,\n    kv_chunk_size: int,\n) -> Tensor:\n    batch_x_heads, k_tokens, k_channels_per_head = key.shape\n    _, _, v_channels_per_head = value.shape\n\n    def chunk_scanner(chunk_idx: int) -> AttnChunk:\n        key_chunk = narrow_trunc(\n            key,\n            1,\n            chunk_idx,\n            kv_chunk_size\n        )\n        value_chunk = narrow_trunc(\n            value,\n            1,\n            chunk_idx,\n            kv_chunk_size\n        )\n        return summarize_chunk(query, key_chunk, value_chunk)\n\n    chunks: list[AttnChunk] = [\n        chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)\n    ]\n    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))\n    chunk_values, chunk_weights, chunk_max = acc_chunk\n\n    global_max, _ = torch.max(chunk_max, 0, keepdim=True)\n    max_diffs = torch.exp(chunk_max - global_max)\n    chunk_values *= torch.unsqueeze(max_diffs, -1)\n    chunk_weights *= max_diffs\n\n    all_values = chunk_values.sum(dim=0)\n    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)\n    return all_values / all_weights\n\n\n# TODO: refactor CrossAttention#get_attention_scores to share code with this\ndef _get_attention_scores_no_kv_chunking(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    scale: float,\n) -> Tensor:\n    attn_scores = torch.baddbmm(\n        torch.zeros(1, 1, 1, device=query.device, dtype=query.dtype),\n        query,\n        key.transpose(1,2),\n        alpha=scale,\n        beta=0,\n    )\n    attn_probs = attn_scores.softmax(dim=-1)\n    del attn_scores\n    hidden_states_slice = 
torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)\n    return hidden_states_slice\n\n\nclass ScannedChunk(NamedTuple):\n    chunk_idx: int\n    attn_chunk: AttnChunk\n\n\ndef efficient_dot_product_attention(\n    query: Tensor,\n    key: Tensor,\n    value: Tensor,\n    query_chunk_size=1024,\n    kv_chunk_size: Optional[int] = None,\n    kv_chunk_size_min: Optional[int] = None,\n    use_checkpoint=True,\n):\n    \"\"\"Computes efficient dot-product attention given query, key, and value.\n      This is efficient version of attention presented in\n      https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.\n      Args:\n        query: queries for calculating attention with shape of\n          `[batch * num_heads, tokens, channels_per_head]`.\n        key: keys for calculating attention with shape of\n          `[batch * num_heads, tokens, channels_per_head]`.\n        value: values to be used in attention with shape of\n          `[batch * num_heads, tokens, channels_per_head]`.\n        query_chunk_size: int: query chunks size\n        kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)\n        kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. 
changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).\n        use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)\n      Returns:\n        Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.\n      \"\"\"\n    batch_x_heads, q_tokens, q_channels_per_head = query.shape\n    _, k_tokens, _ = key.shape\n    scale = q_channels_per_head ** -0.5\n\n    kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)\n    if kv_chunk_size_min is not None:\n        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)\n\n    def get_query_chunk(chunk_idx: int) -> Tensor:\n        return narrow_trunc(\n            query,\n            1,\n            chunk_idx,\n            min(query_chunk_size, q_tokens)\n        )\n\n    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)\n    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk\n    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(\n        _get_attention_scores_no_kv_chunking,\n        scale=scale\n    ) if k_tokens <= kv_chunk_size else (\n        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)\n        partial(\n            _query_chunk_attention,\n            kv_chunk_size=kv_chunk_size,\n            summarize_chunk=summarize_chunk,\n        )\n    )\n\n    if q_tokens <= query_chunk_size:\n        # fast-path for when there's just 1 query chunk\n        return compute_query_chunk_attn(\n            query=query,\n            key=key,\n            value=value,\n        )\n\n    res = torch.zeros_like(query)\n    for i in range(math.ceil(q_tokens / query_chunk_size)):\n        attn_scores = compute_query_chunk_attn(\n            query=get_query_chunk(i * query_chunk_size),\n 
           key=key,\n            value=value,\n        )\n\n        res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores\n\n    return res\n"
  },
  {
    "path": "modules/sysinfo.py",
    "content": "import json\r\nimport os\r\nimport sys\r\nimport subprocess\r\nimport platform\r\nimport hashlib\r\nimport re\r\nfrom pathlib import Path\r\n\r\nfrom modules import paths_internal, timer, shared_cmd_options, errors, launch_utils\r\n\r\nchecksum_token = \"DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY\"\r\nenvironment_whitelist = {\r\n    \"GIT\",\r\n    \"INDEX_URL\",\r\n    \"WEBUI_LAUNCH_LIVE_OUTPUT\",\r\n    \"GRADIO_ANALYTICS_ENABLED\",\r\n    \"PYTHONPATH\",\r\n    \"TORCH_INDEX_URL\",\r\n    \"TORCH_COMMAND\",\r\n    \"REQS_FILE\",\r\n    \"XFORMERS_PACKAGE\",\r\n    \"CLIP_PACKAGE\",\r\n    \"OPENCLIP_PACKAGE\",\r\n    \"ASSETS_REPO\",\r\n    \"STABLE_DIFFUSION_REPO\",\r\n    \"K_DIFFUSION_REPO\",\r\n    \"BLIP_REPO\",\r\n    \"ASSETS_COMMIT_HASH\",\r\n    \"STABLE_DIFFUSION_COMMIT_HASH\",\r\n    \"K_DIFFUSION_COMMIT_HASH\",\r\n    \"BLIP_COMMIT_HASH\",\r\n    \"COMMANDLINE_ARGS\",\r\n    \"IGNORE_CMD_ARGS_ERRORS\",\r\n}\r\n\r\n\r\ndef pretty_bytes(num, suffix=\"B\"):\r\n    for unit in [\"\", \"K\", \"M\", \"G\", \"T\", \"P\", \"E\", \"Z\", \"Y\"]:\r\n        if abs(num) < 1024 or unit == 'Y':\r\n            return f\"{num:.0f}{unit}{suffix}\"\r\n        num /= 1024\r\n\r\n\r\ndef get():\r\n    res = get_dict()\r\n\r\n    text = json.dumps(res, ensure_ascii=False, indent=4)\r\n\r\n    h = hashlib.sha256(text.encode(\"utf8\"))\r\n    text = text.replace(checksum_token, h.hexdigest())\r\n\r\n    return text\r\n\r\n\r\nre_checksum = re.compile(r'\"Checksum\": \"([0-9a-fA-F]{64})\"')\r\n\r\n\r\ndef check(x):\r\n    m = re.search(re_checksum, x)\r\n    if not m:\r\n        return False\r\n\r\n    replaced = re.sub(re_checksum, f'\"Checksum\": \"{checksum_token}\"', x)\r\n\r\n    h = hashlib.sha256(replaced.encode(\"utf8\"))\r\n    return h.hexdigest() == m.group(1)\r\n\r\n\r\ndef get_cpu_info():\r\n    cpu_info = {\"model\": platform.processor()}\r\n    try:\r\n        import psutil\r\n        cpu_info[\"count logical\"] = 
psutil.cpu_count(logical=True)\r\n        cpu_info[\"count physical\"] = psutil.cpu_count(logical=False)\r\n    except Exception as e:\r\n        cpu_info[\"error\"] = str(e)\r\n    return cpu_info\r\n\r\n\r\ndef get_ram_info():\r\n    try:\r\n        import psutil\r\n        ram = psutil.virtual_memory()\r\n        return {x: pretty_bytes(getattr(ram, x, 0)) for x in [\"total\", \"used\", \"free\", \"active\", \"inactive\", \"buffers\", \"cached\", \"shared\"] if getattr(ram, x, 0) != 0}\r\n    except Exception as e:\r\n        return str(e)\r\n\r\n\r\ndef get_packages():\r\n    try:\r\n        return subprocess.check_output([sys.executable, '-m', 'pip', 'freeze', '--all']).decode(\"utf8\").splitlines()\r\n    except Exception as pip_error:\r\n        try:\r\n            import importlib.metadata\r\n            packages = importlib.metadata.distributions()\r\n            return sorted([f\"{package.metadata['Name']}=={package.version}\" for package in packages])\r\n        except Exception as e2:\r\n            return {'error pip': pip_error, 'error importlib': str(e2)}\r\n\r\n\r\ndef get_dict():\r\n    config = get_config()\r\n    res = {\r\n        \"Platform\": platform.platform(),\r\n        \"Python\": platform.python_version(),\r\n        \"Version\": launch_utils.git_tag(),\r\n        \"Commit\": launch_utils.commit_hash(),\r\n        \"Git status\": git_status(paths_internal.script_path),\r\n        \"Script path\": paths_internal.script_path,\r\n        \"Data path\": paths_internal.data_path,\r\n        \"Extensions dir\": paths_internal.extensions_dir,\r\n        \"Checksum\": checksum_token,\r\n        \"Commandline\": get_argv(),\r\n        \"Torch env info\": get_torch_sysinfo(),\r\n        \"Exceptions\": errors.get_exceptions(),\r\n        \"CPU\": get_cpu_info(),\r\n        \"RAM\": get_ram_info(),\r\n        \"Extensions\": get_extensions(enabled=True, fallback_disabled_extensions=config.get('disabled_extensions', [])),\r\n        \"Inactive 
extensions\": get_extensions(enabled=False, fallback_disabled_extensions=config.get('disabled_extensions', [])),\r\n        \"Environment\": get_environment(),\r\n        \"Config\": config,\r\n        \"Startup\": timer.startup_record,\r\n        \"Packages\": get_packages(),\r\n    }\r\n\r\n    return res\r\n\r\n\r\ndef get_environment():\r\n    return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist}\r\n\r\n\r\ndef get_argv():\r\n    res = []\r\n\r\n    for v in sys.argv:\r\n        if shared_cmd_options.cmd_opts.gradio_auth and shared_cmd_options.cmd_opts.gradio_auth == v:\r\n            res.append(\"<hidden>\")\r\n            continue\r\n\r\n        if shared_cmd_options.cmd_opts.api_auth and shared_cmd_options.cmd_opts.api_auth == v:\r\n            res.append(\"<hidden>\")\r\n            continue\r\n\r\n        res.append(v)\r\n\r\n    return res\r\n\r\n\r\nre_newline = re.compile(r\"\\r*\\n\")\r\n\r\n\r\ndef get_torch_sysinfo():\r\n    try:\r\n        import torch.utils.collect_env\r\n        info = torch.utils.collect_env.get_env_info()._asdict()\r\n\r\n        return {k: re.split(re_newline, str(v)) if \"\\n\" in str(v) else v for k, v in info.items()}\r\n    except Exception as e:\r\n        return str(e)\r\n\r\n\r\ndef run_git(path, *args):\r\n    try:\r\n        return subprocess.check_output([launch_utils.git, '-C', path, *args], shell=False, encoding='utf8').strip()\r\n    except Exception as e:\r\n        return str(e)\r\n\r\n\r\ndef git_status(path):\r\n    if (Path(path) / '.git').is_dir():\r\n        return run_git(paths_internal.script_path, 'status')\r\n\r\n\r\ndef get_info_from_repo_path(path: Path):\r\n    is_repo = (path / '.git').is_dir()\r\n    return {\r\n        'name': path.name,\r\n        'path': str(path),\r\n        'commit': run_git(path, 'rev-parse', 'HEAD') if is_repo else None,\r\n        'branch': run_git(path, 'branch', '--show-current') if is_repo else None,\r\n        'remote': run_git(path, 
'remote', 'get-url', 'origin') if is_repo else None,\r\n    }\r\n\r\n\r\ndef get_extensions(*, enabled, fallback_disabled_extensions=None):\r\n    try:\r\n        from modules import extensions\r\n        if extensions.extensions:\r\n            def to_json(x: extensions.Extension):\r\n                return {\r\n                    \"name\": x.name,\r\n                    \"path\": x.path,\r\n                    \"commit\": x.commit_hash,\r\n                    \"branch\": x.branch,\r\n                    \"remote\": x.remote,\r\n                }\r\n            return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled]\r\n        else:\r\n            return [get_info_from_repo_path(d) for d in Path(paths_internal.extensions_dir).iterdir() if d.is_dir() and enabled != (str(d.name) in fallback_disabled_extensions)]\r\n    except Exception as e:\r\n        return str(e)\r\n\r\n\r\ndef get_config():\r\n    try:\r\n        from modules import shared\r\n        return shared.opts.data\r\n    except Exception as _:\r\n        try:\r\n            with open(shared_cmd_options.cmd_opts.ui_settings_file, 'r') as f:\r\n                return json.load(f)\r\n        except Exception as e:\r\n            return str(e)\r\n"
  },
  {
    "path": "modules/textual_inversion/autocrop.py",
    "content": "import cv2\r\nimport requests\r\nimport os\r\nimport numpy as np\r\nfrom PIL import ImageDraw\r\nfrom modules import paths_internal\r\nfrom pkg_resources import parse_version\r\n\r\nGREEN = \"#0F0\"\r\nBLUE = \"#00F\"\r\nRED = \"#F00\"\r\n\r\n\r\ndef crop_image(im, settings):\r\n    \"\"\" Intelligently crop an image to the subject matter \"\"\"\r\n\r\n    scale_by = 1\r\n    if is_landscape(im.width, im.height):\r\n        scale_by = settings.crop_height / im.height\r\n    elif is_portrait(im.width, im.height):\r\n        scale_by = settings.crop_width / im.width\r\n    elif is_square(im.width, im.height):\r\n        if is_square(settings.crop_width, settings.crop_height):\r\n            scale_by = settings.crop_width / im.width\r\n        elif is_landscape(settings.crop_width, settings.crop_height):\r\n            scale_by = settings.crop_width / im.width\r\n        elif is_portrait(settings.crop_width, settings.crop_height):\r\n            scale_by = settings.crop_height / im.height\r\n\r\n    im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))\r\n    im_debug = im.copy()\r\n\r\n    focus = focal_point(im_debug, settings)\r\n\r\n    # take the focal point and turn it into crop coordinates that try to center over the focal\r\n    # point but then get adjusted back into the frame\r\n    y_half = int(settings.crop_height / 2)\r\n    x_half = int(settings.crop_width / 2)\r\n\r\n    x1 = focus.x - x_half\r\n    if x1 < 0:\r\n        x1 = 0\r\n    elif x1 + settings.crop_width > im.width:\r\n        x1 = im.width - settings.crop_width\r\n\r\n    y1 = focus.y - y_half\r\n    if y1 < 0:\r\n        y1 = 0\r\n    elif y1 + settings.crop_height > im.height:\r\n        y1 = im.height - settings.crop_height\r\n\r\n    x2 = x1 + settings.crop_width\r\n    y2 = y1 + settings.crop_height\r\n\r\n    crop = [x1, y1, x2, y2]\r\n\r\n    results = []\r\n\r\n    results.append(im.crop(tuple(crop)))\r\n\r\n    if settings.annotate_image:\r\n        d 
= ImageDraw.Draw(im_debug)\r\n        rect = list(crop)\r\n        rect[2] -= 1\r\n        rect[3] -= 1\r\n        d.rectangle(rect, outline=GREEN)\r\n        results.append(im_debug)\r\n        if settings.desktop_view_image:\r\n            im_debug.show()\r\n\r\n    return results\r\n\r\n\r\ndef focal_point(im, settings):\r\n    corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []\r\n    entropy_points = image_entropy_points(im, settings) if settings.entropy_points_weight > 0 else []\r\n    face_points = image_face_points(im, settings) if settings.face_points_weight > 0 else []\r\n\r\n    pois = []\r\n\r\n    weight_pref_total = 0\r\n    if corner_points:\r\n        weight_pref_total += settings.corner_points_weight\r\n    if entropy_points:\r\n        weight_pref_total += settings.entropy_points_weight\r\n    if face_points:\r\n        weight_pref_total += settings.face_points_weight\r\n\r\n    corner_centroid = None\r\n    if corner_points:\r\n        corner_centroid = centroid(corner_points)\r\n        corner_centroid.weight = settings.corner_points_weight / weight_pref_total\r\n        pois.append(corner_centroid)\r\n\r\n    entropy_centroid = None\r\n    if entropy_points:\r\n        entropy_centroid = centroid(entropy_points)\r\n        entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total\r\n        pois.append(entropy_centroid)\r\n\r\n    face_centroid = None\r\n    if face_points:\r\n        face_centroid = centroid(face_points)\r\n        face_centroid.weight = settings.face_points_weight / weight_pref_total\r\n        pois.append(face_centroid)\r\n\r\n    average_point = poi_average(pois, settings)\r\n\r\n    if settings.annotate_image:\r\n        d = ImageDraw.Draw(im)\r\n        max_size = min(im.width, im.height) * 0.07\r\n        if corner_centroid is not None:\r\n            color = BLUE\r\n            box = corner_centroid.bounding(max_size * corner_centroid.weight)\r\n            
d.text((box[0], box[1] - 15), f\"Edge: {corner_centroid.weight:.02f}\", fill=color)\r\n            d.ellipse(box, outline=color)\r\n            if len(corner_points) > 1:\r\n                for f in corner_points:\r\n                    d.rectangle(f.bounding(4), outline=color)\r\n        if entropy_centroid is not None:\r\n            color = \"#ff0\"\r\n            box = entropy_centroid.bounding(max_size * entropy_centroid.weight)\r\n            d.text((box[0], box[1] - 15), f\"Entropy: {entropy_centroid.weight:.02f}\", fill=color)\r\n            d.ellipse(box, outline=color)\r\n            if len(entropy_points) > 1:\r\n                for f in entropy_points:\r\n                    d.rectangle(f.bounding(4), outline=color)\r\n        if face_centroid is not None:\r\n            color = RED\r\n            box = face_centroid.bounding(max_size * face_centroid.weight)\r\n            d.text((box[0], box[1] - 15), f\"Face: {face_centroid.weight:.02f}\", fill=color)\r\n            d.ellipse(box, outline=color)\r\n            if len(face_points) > 1:\r\n                for f in face_points:\r\n                    d.rectangle(f.bounding(4), outline=color)\r\n\r\n        d.ellipse(average_point.bounding(max_size), outline=GREEN)\r\n\r\n    return average_point\r\n\r\n\r\ndef image_face_points(im, settings):\r\n    if settings.dnn_model_path is not None:\r\n        detector = cv2.FaceDetectorYN.create(\r\n            settings.dnn_model_path,\r\n            \"\",\r\n            (im.width, im.height),\r\n            0.9,  # score threshold\r\n            0.3,  # nms threshold\r\n            5000  # keep top k before nms\r\n        )\r\n        faces = detector.detect(np.array(im))\r\n        results = []\r\n        if faces[1] is not None:\r\n            for face in faces[1]:\r\n                x = face[0]\r\n                y = face[1]\r\n                w = face[2]\r\n                h = face[3]\r\n                results.append(\r\n                    
PointOfInterest(\r\n                        int(x + (w * 0.5)),  # face focus left/right is center\r\n                        int(y + (h * 0.33)),  # face focus up/down is close to the top of the head\r\n                        size=w,\r\n                        weight=1 / len(faces[1])\r\n                    )\r\n                )\r\n        return results\r\n    else:\r\n        np_im = np.array(im)\r\n        gray = cv2.cvtColor(np_im, cv2.COLOR_BGR2GRAY)\r\n\r\n        tries = [\r\n            [f'{cv2.data.haarcascades}haarcascade_eye.xml', 0.01],\r\n            [f'{cv2.data.haarcascades}haarcascade_frontalface_default.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_profileface.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt2.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_frontalface_alt_tree.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_eye_tree_eyeglasses.xml', 0.05],\r\n            [f'{cv2.data.haarcascades}haarcascade_upperbody.xml', 0.05]\r\n        ]\r\n        for t in tries:\r\n            classifier = cv2.CascadeClassifier(t[0])\r\n            minsize = int(min(im.width, im.height) * t[1])  # at least N percent of the smallest side\r\n            try:\r\n                faces = classifier.detectMultiScale(gray, scaleFactor=1.1,\r\n                                                    minNeighbors=7, minSize=(minsize, minsize),\r\n                                                    flags=cv2.CASCADE_SCALE_IMAGE)\r\n            except Exception:\r\n                continue\r\n\r\n            if faces:\r\n                rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces]\r\n                return [PointOfInterest((r[0] + r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0] - r[2]),\r\n                                        weight=1 / len(rects)) for r in rects]\r\n    return 
[]\r\n\r\n\r\ndef image_corner_points(im, settings):\r\n    grayscale = im.convert(\"L\")\r\n\r\n    # naive attempt at preventing focal points from collecting at watermarks near the bottom\r\n    gd = ImageDraw.Draw(grayscale)\r\n    gd.rectangle([0, im.height * .9, im.width, im.height], fill=\"#999\")\r\n\r\n    np_im = np.array(grayscale)\r\n\r\n    points = cv2.goodFeaturesToTrack(\r\n        np_im,\r\n        maxCorners=100,\r\n        qualityLevel=0.04,\r\n        minDistance=min(grayscale.width, grayscale.height) * 0.06,\r\n        useHarrisDetector=False,\r\n    )\r\n\r\n    if points is None:\r\n        return []\r\n\r\n    focal_points = []\r\n    for point in points:\r\n        x, y = point.ravel()\r\n        focal_points.append(PointOfInterest(x, y, size=4, weight=1 / len(points)))\r\n\r\n    return focal_points\r\n\r\n\r\ndef image_entropy_points(im, settings):\r\n    landscape = im.height < im.width\r\n    portrait = im.height > im.width\r\n    if landscape:\r\n        move_idx = [0, 2]\r\n        move_max = im.size[0]\r\n    elif portrait:\r\n        move_idx = [1, 3]\r\n        move_max = im.size[1]\r\n    else:\r\n        return []\r\n\r\n    e_max = 0\r\n    crop_current = [0, 0, settings.crop_width, settings.crop_height]\r\n    crop_best = crop_current\r\n    while crop_current[move_idx[1]] < move_max:\r\n        crop = im.crop(tuple(crop_current))\r\n        e = image_entropy(crop)\r\n\r\n        if (e > e_max):\r\n            e_max = e\r\n            crop_best = list(crop_current)\r\n\r\n        crop_current[move_idx[0]] += 4\r\n        crop_current[move_idx[1]] += 4\r\n\r\n    x_mid = int(crop_best[0] + settings.crop_width / 2)\r\n    y_mid = int(crop_best[1] + settings.crop_height / 2)\r\n\r\n    return [PointOfInterest(x_mid, y_mid, size=25, weight=1.0)]\r\n\r\n\r\ndef image_entropy(im):\r\n    # greyscale image entropy\r\n    # band = np.asarray(im.convert(\"L\"))\r\n    band = np.asarray(im.convert(\"1\"), dtype=np.uint8)\r\n    hist, _ = 
np.histogram(band, bins=range(0, 256))\r\n    hist = hist[hist > 0]\r\n    return -np.log2(hist / hist.sum()).sum()\r\n\r\n\r\ndef centroid(pois):\r\n    x = [poi.x for poi in pois]\r\n    y = [poi.y for poi in pois]\r\n    return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))\r\n\r\n\r\ndef poi_average(pois, settings):\r\n    weight = 0.0\r\n    x = 0.0\r\n    y = 0.0\r\n    for poi in pois:\r\n        weight += poi.weight\r\n        x += poi.x * poi.weight\r\n        y += poi.y * poi.weight\r\n    avg_x = round(weight and x / weight)\r\n    avg_y = round(weight and y / weight)\r\n\r\n    return PointOfInterest(avg_x, avg_y)\r\n\r\n\r\ndef is_landscape(w, h):\r\n    return w > h\r\n\r\n\r\ndef is_portrait(w, h):\r\n    return h > w\r\n\r\n\r\ndef is_square(w, h):\r\n    return w == h\r\n\r\n\r\nmodel_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')\r\nif parse_version(cv2.__version__) >= parse_version('4.8'):\r\n    model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')\r\n    model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'\r\nelse:\r\n    model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')\r\n    model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'\r\n\r\n\r\ndef download_and_cache_models():\r\n    if not os.path.exists(model_file_path):\r\n        os.makedirs(model_dir_opencv, exist_ok=True)\r\n        print(f\"downloading face detection model from '{model_url}' to '{model_file_path}'\")\r\n        response = requests.get(model_url)\r\n        with open(model_file_path, \"wb\") as f:\r\n            f.write(response.content)\r\n    return model_file_path\r\n\r\n\r\nclass PointOfInterest:\r\n    def __init__(self, x, y, weight=1.0, size=10):\r\n        
self.x = x\r\n        self.y = y\r\n        self.weight = weight\r\n        self.size = size\r\n\r\n    def bounding(self, size):\r\n        return [\r\n            self.x - size // 2,\r\n            self.y - size // 2,\r\n            self.x + size // 2,\r\n            self.y + size // 2\r\n        ]\r\n\r\n\r\nclass Settings:\r\n    def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):\r\n        self.crop_width = crop_width\r\n        self.crop_height = crop_height\r\n        self.corner_points_weight = corner_points_weight\r\n        self.entropy_points_weight = entropy_points_weight\r\n        self.face_points_weight = face_points_weight\r\n        self.annotate_image = annotate_image\r\n        self.desktop_view_image = False\r\n        self.dnn_model_path = dnn_model_path\r\n"
  },
  {
    "path": "modules/textual_inversion/dataset.py",
    "content": "import os\r\nimport numpy as np\r\nimport PIL\r\nimport torch\r\nfrom torch.utils.data import Dataset, DataLoader, Sampler\r\nfrom torchvision import transforms\r\nfrom collections import defaultdict\r\nfrom random import shuffle, choices\r\n\r\nimport random\r\nimport tqdm\r\nfrom modules import devices, shared, images\r\nimport re\r\n\r\nfrom ldm.modules.distributions.distributions import DiagonalGaussianDistribution\r\n\r\nre_numbers_at_start = re.compile(r\"^[-\\d]+\\s*\")\r\n\r\n\r\nclass DatasetEntry:\r\n    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, weight=None):\r\n        self.filename = filename\r\n        self.filename_text = filename_text\r\n        self.weight = weight\r\n        self.latent_dist = latent_dist\r\n        self.latent_sample = latent_sample\r\n        self.cond = cond\r\n        self.cond_text = cond_text\r\n        self.pixel_values = pixel_values\r\n\r\n\r\nclass PersonalizedBase(Dataset):\r\n    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token=\"*\", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False):\r\n        re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None\r\n\r\n        self.placeholder_token = placeholder_token\r\n\r\n        self.flip = transforms.RandomHorizontalFlip(p=flip_p)\r\n\r\n        self.dataset = []\r\n\r\n        with open(template_file, \"r\") as file:\r\n            lines = [x.strip() for x in file.readlines()]\r\n\r\n        self.lines = lines\r\n\r\n        assert data_root, 'dataset directory not specified'\r\n        assert os.path.isdir(data_root), \"Dataset directory doesn't exist\"\r\n        assert os.listdir(data_root), \"Dataset 
directory is empty\"\r\n\r\n        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]\r\n\r\n        self.shuffle_tags = shuffle_tags\r\n        self.tag_drop_out = tag_drop_out\r\n        groups = defaultdict(list)\r\n\r\n        print(\"Preparing dataset...\")\r\n        for path in tqdm.tqdm(self.image_paths):\r\n            alpha_channel = None\r\n            if shared.state.interrupted:\r\n                raise Exception(\"interrupted\")\r\n            try:\r\n                image = images.read(path)\r\n                #Currently does not work for single color transparency\r\n                #We would need to read image.info['transparency'] for that\r\n                if use_weight and 'A' in image.getbands():\r\n                    alpha_channel = image.getchannel('A')\r\n                image = image.convert('RGB')\r\n                if not varsize:\r\n                    image = image.resize((width, height), PIL.Image.BICUBIC)\r\n            except Exception:\r\n                continue\r\n\r\n            text_filename = f\"{os.path.splitext(path)[0]}.txt\"\r\n            filename = os.path.basename(path)\r\n\r\n            if os.path.exists(text_filename):\r\n                with open(text_filename, \"r\", encoding=\"utf8\") as file:\r\n                    filename_text = file.read()\r\n            else:\r\n                filename_text = os.path.splitext(filename)[0]\r\n                filename_text = re.sub(re_numbers_at_start, '', filename_text)\r\n                if re_word:\r\n                    tokens = re_word.findall(filename_text)\r\n                    filename_text = (shared.opts.dataset_filename_join_string or \"\").join(tokens)\r\n\r\n            npimage = np.array(image).astype(np.uint8)\r\n            npimage = (npimage / 127.5 - 1.0).astype(np.float32)\r\n\r\n            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)\r\n            latent_sample = 
None\r\n\r\n            with devices.autocast():\r\n                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))\r\n\r\n            #Perform latent sampling, even for random sampling.\r\n            #We need the sample dimensions for the weights\r\n            if latent_sampling_method == \"deterministic\":\r\n                if isinstance(latent_dist, DiagonalGaussianDistribution):\r\n                    # Works only for DiagonalGaussianDistribution\r\n                    latent_dist.std = 0\r\n                else:\r\n                    latent_sampling_method = \"once\"\r\n            latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)\r\n\r\n            if use_weight and alpha_channel is not None:\r\n                channels, *latent_size = latent_sample.shape\r\n                weight_img = alpha_channel.resize(latent_size)\r\n                npweight = np.array(weight_img).astype(np.float32)\r\n                #Repeat for every channel in the latent sample\r\n                weight = torch.tensor([npweight] * channels).reshape([channels] + latent_size)\r\n                #Normalize the weight to a minimum of 0 and a mean of 1, that way the loss will be comparable to default.\r\n                weight -= weight.min()\r\n                weight /= weight.mean()\r\n            elif use_weight:\r\n                #If an image does not have a alpha channel, add a ones weight map anyway so we can stack it later\r\n                weight = torch.ones(latent_sample.shape)\r\n            else:\r\n                weight = None\r\n\r\n            if latent_sampling_method == \"random\":\r\n                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)\r\n            else:\r\n                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, weight=weight)\r\n\r\n            if not (self.tag_drop_out != 0 or 
self.shuffle_tags):\r\n                entry.cond_text = self.create_text(filename_text)\r\n\r\n            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):\r\n                with devices.autocast():\r\n                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)\r\n            groups[image.size].append(len(self.dataset))\r\n            self.dataset.append(entry)\r\n            del torchdata\r\n            del latent_dist\r\n            del latent_sample\r\n            del weight\r\n\r\n        self.length = len(self.dataset)\r\n        self.groups = list(groups.values())\r\n        assert self.length > 0, \"No images have been found in the dataset.\"\r\n        self.batch_size = min(batch_size, self.length)\r\n        self.gradient_step = min(gradient_step, self.length // self.batch_size)\r\n        self.latent_sampling_method = latent_sampling_method\r\n\r\n        if len(groups) > 1:\r\n            print(\"Buckets:\")\r\n            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):\r\n                print(f\"  {w}x{h}: {len(ids)}\")\r\n            print()\r\n\r\n    def create_text(self, filename_text):\r\n        text = random.choice(self.lines)\r\n        tags = filename_text.split(',')\r\n        if self.tag_drop_out != 0:\r\n            tags = [t for t in tags if random.random() > self.tag_drop_out]\r\n        if self.shuffle_tags:\r\n            random.shuffle(tags)\r\n        text = text.replace(\"[filewords]\", ','.join(tags))\r\n        text = text.replace(\"[name]\", self.placeholder_token)\r\n        return text\r\n\r\n    def __len__(self):\r\n        return self.length\r\n\r\n    def __getitem__(self, i):\r\n        entry = self.dataset[i]\r\n        if self.tag_drop_out != 0 or self.shuffle_tags:\r\n            entry.cond_text = self.create_text(entry.filename_text)\r\n        if self.latent_sampling_method == \"random\":\r\n            entry.latent_sample = 
shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)\r\n        return entry\r\n\r\n\r\nclass GroupedBatchSampler(Sampler):\r\n    def __init__(self, data_source: PersonalizedBase, batch_size: int):\r\n        super().__init__(data_source)\r\n\r\n        n = len(data_source)\r\n        self.groups = data_source.groups\r\n        self.len = n_batch = n // batch_size\r\n        expected = [len(g) / n * n_batch * batch_size for g in data_source.groups]\r\n        self.base = [int(e) // batch_size for e in expected]\r\n        self.n_rand_batches = nrb = n_batch - sum(self.base)\r\n        self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]\r\n        self.batch_size = batch_size\r\n\r\n    def __len__(self):\r\n        return self.len\r\n\r\n    def __iter__(self):\r\n        b = self.batch_size\r\n\r\n        for g in self.groups:\r\n            shuffle(g)\r\n\r\n        batches = []\r\n        for g in self.groups:\r\n            batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))\r\n        for _ in range(self.n_rand_batches):\r\n            rand_group = choices(self.groups, self.probs)[0]\r\n            batches.append(choices(rand_group, k=b))\r\n\r\n        shuffle(batches)\r\n\r\n        yield from batches\r\n\r\n\r\nclass PersonalizedDataLoader(DataLoader):\r\n    def __init__(self, dataset, latent_sampling_method=\"once\", batch_size=1, pin_memory=False):\r\n        super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)\r\n        if latent_sampling_method == \"random\":\r\n            self.collate_fn = collate_wrapper_random\r\n        else:\r\n            self.collate_fn = collate_wrapper\r\n\r\n\r\nclass BatchLoader:\r\n    def __init__(self, data):\r\n        self.cond_text = [entry.cond_text for entry in data]\r\n        self.cond = [entry.cond for entry in data]\r\n        self.latent_sample = 
torch.stack([entry.latent_sample for entry in data]).squeeze(1)\r\n        if all(entry.weight is not None for entry in data):\r\n            self.weight = torch.stack([entry.weight for entry in data]).squeeze(1)\r\n        else:\r\n            self.weight = None\r\n        #self.emb_index = [entry.emb_index for entry in data]\r\n        #print(self.latent_sample.device)\r\n\r\n    def pin_memory(self):\r\n        self.latent_sample = self.latent_sample.pin_memory()\r\n        return self\r\n\r\ndef collate_wrapper(batch):\r\n    return BatchLoader(batch)\r\n\r\nclass BatchLoaderRandom(BatchLoader):\r\n    def __init__(self, data):\r\n        super().__init__(data)\r\n\r\n    def pin_memory(self):\r\n        return self\r\n\r\ndef collate_wrapper_random(batch):\r\n    return BatchLoaderRandom(batch)\r\n"
  },
  {
    "path": "modules/textual_inversion/image_embedding.py",
    "content": "import base64\r\nimport json\r\nimport os.path\r\nimport warnings\r\nimport logging\r\n\r\nimport numpy as np\r\nimport zlib\r\nfrom PIL import Image, ImageDraw\r\nimport torch\r\n\r\nlogger = logging.getLogger(__name__)\r\n\r\n\r\nclass EmbeddingEncoder(json.JSONEncoder):\r\n    def default(self, obj):\r\n        if isinstance(obj, torch.Tensor):\r\n            return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}\r\n        return json.JSONEncoder.default(self, obj)\r\n\r\n\r\nclass EmbeddingDecoder(json.JSONDecoder):\r\n    def __init__(self, *args, **kwargs):\r\n        json.JSONDecoder.__init__(self, *args, object_hook=self.object_hook, **kwargs)\r\n\r\n    def object_hook(self, d):\r\n        if 'TORCHTENSOR' in d:\r\n            return torch.from_numpy(np.array(d['TORCHTENSOR']))\r\n        return d\r\n\r\n\r\ndef embedding_to_b64(data):\r\n    d = json.dumps(data, cls=EmbeddingEncoder)\r\n    return base64.b64encode(d.encode())\r\n\r\n\r\ndef embedding_from_b64(data):\r\n    d = base64.b64decode(data)\r\n    return json.loads(d, cls=EmbeddingDecoder)\r\n\r\n\r\ndef lcg(m=2**32, a=1664525, c=1013904223, seed=0):\r\n    while True:\r\n        seed = (a * seed + c) % m\r\n        yield seed % 255\r\n\r\n\r\ndef xor_block(block):\r\n    g = lcg()\r\n    randblock = np.array([next(g) for _ in range(np.prod(block.shape))]).astype(np.uint8).reshape(block.shape)\r\n    return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)\r\n\r\n\r\ndef style_block(block, sequence):\r\n    im = Image.new('RGB', (block.shape[1], block.shape[0]))\r\n    draw = ImageDraw.Draw(im)\r\n    i = 0\r\n    for x in range(-6, im.size[0], 8):\r\n        for yi, y in enumerate(range(-6, im.size[1], 8)):\r\n            offset = 0\r\n            if yi % 2 == 0:\r\n                offset = 4\r\n            shade = sequence[i % len(sequence)]\r\n            i += 1\r\n            draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))\r\n\r\n    fg = 
np.array(im).astype(np.uint8) & 0xF0\r\n\r\n    return block ^ fg\r\n\r\n\r\ndef insert_image_data_embed(image, data):\r\n    d = 3\r\n    data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)\r\n    data_np_ = np.frombuffer(data_compressed, np.uint8).copy()\r\n    data_np_high = data_np_ >> 4\r\n    data_np_low = data_np_ & 0x0F\r\n\r\n    h = image.size[1]\r\n    next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))\r\n    next_size = next_size + ((h*d)-(next_size % (h*d)))\r\n\r\n    data_np_low = np.resize(data_np_low, next_size)\r\n    data_np_low = data_np_low.reshape((h, -1, d))\r\n\r\n    data_np_high = np.resize(data_np_high, next_size)\r\n    data_np_high = data_np_high.reshape((h, -1, d))\r\n\r\n    edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]\r\n    edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)\r\n\r\n    data_np_low = style_block(data_np_low, sequence=edge_style)\r\n    data_np_low = xor_block(data_np_low)\r\n    data_np_high = style_block(data_np_high, sequence=edge_style[::-1])\r\n    data_np_high = xor_block(data_np_high)\r\n\r\n    im_low = Image.fromarray(data_np_low, mode='RGB')\r\n    im_high = Image.fromarray(data_np_high, mode='RGB')\r\n\r\n    background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))\r\n    background.paste(im_low, (0, 0))\r\n    background.paste(image, (im_low.size[0]+1, 0))\r\n    background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))\r\n\r\n    return background\r\n\r\n\r\ndef crop_black(img, tol=0):\r\n    mask = (img > tol).all(2)\r\n    mask0, mask1 = mask.any(0), mask.any(1)\r\n    col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()\r\n    row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()\r\n    return img[row_start:row_end, col_start:col_end]\r\n\r\n\r\ndef 
extract_image_data_embed(image):\r\n    d = 3\r\n    outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F\r\n    black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)\r\n    if black_cols[0].shape[0] < 2:\r\n        logger.debug(f'{os.path.basename(getattr(image, \"filename\", \"unknown image file\"))}: no embedded information found.')\r\n        return None\r\n\r\n    data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)\r\n    data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)\r\n\r\n    data_block_lower = xor_block(data_block_lower)\r\n    data_block_upper = xor_block(data_block_upper)\r\n\r\n    data_block = (data_block_upper << 4) | (data_block_lower)\r\n    data_block = data_block.flatten().tobytes()\r\n\r\n    data = zlib.decompress(data_block)\r\n    return json.loads(data, cls=EmbeddingDecoder)\r\n\r\n\r\ndef caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):\r\n    from modules.images import get_font\r\n    if textfont:\r\n        warnings.warn(\r\n            'passing in a textfont to caption_image_overlay is deprecated and does nothing',\r\n            DeprecationWarning,\r\n            stacklevel=2,\r\n        )\r\n    from math import cos\r\n\r\n    image = srcimage.copy()\r\n    fontsize = 32\r\n    factor = 1.5\r\n    gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))\r\n    for y in range(image.size[1]):\r\n        mag = 1-cos(y/image.size[1]*factor)\r\n        mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))\r\n        gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))\r\n    image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))\r\n\r\n    draw = ImageDraw.Draw(image)\r\n\r\n    font = get_font(fontsize)\r\n    padding = 10\r\n\r\n    _, _, w, h = draw.textbbox((0, 0), title, font=font)\r\n    fontsize = min(int(fontsize * 
(((image.size[0]*0.75)-(padding*4))/w)), 72)\r\n    font = get_font(fontsize)\r\n    _, _, w, h = draw.textbbox((0, 0), title, font=font)\r\n    draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))\r\n\r\n    _, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)\r\n    fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\r\n    _, _, w, h = draw.textbbox((0, 0), footerMid, font=font)\r\n    fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\r\n    _, _, w, h = draw.textbbox((0, 0), footerRight, font=font)\r\n    fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)\r\n\r\n    font = get_font(min(fontsize_left, fontsize_mid, fontsize_right))\r\n\r\n    draw.text((padding, image.size[1]-padding),               footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))\r\n    draw.text((image.size[0]/2, image.size[1]-padding),       footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))\r\n    draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))\r\n\r\n    return image\r\n\r\n\r\nif __name__ == '__main__':\r\n\r\n    testEmbed = Image.open('test_embedding.png')\r\n    data = extract_image_data_embed(testEmbed)\r\n    assert data is not None\r\n\r\n    data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])\r\n    assert data is not None\r\n\r\n    image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))\r\n    cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')\r\n\r\n    test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}\r\n\r\n    embedded_image = insert_image_data_embed(cap_image, test_embed)\r\n\r\n    retrieved_embed = extract_image_data_embed(embedded_image)\r\n\r\n    assert str(retrieved_embed) == str(test_embed)\r\n\r\n    embedded_image2 = insert_image_data_embed(cap_image, 
retrieved_embed)\r\n\r\n    assert embedded_image == embedded_image2\r\n\r\n    g = lcg()\r\n    shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()\r\n\r\n    reference_random = [253, 242, 127,  44, 157,  27, 239, 133,  38,  79, 167,   4, 177,\r\n                         95, 130,  79,  78,  14,  52, 215, 220, 194, 126,  28, 240, 179,\r\n                        160, 153, 149,  50, 105,  14,  21, 218, 199,  18,  54, 198, 193,\r\n                         38, 128,  19,  53, 195, 124,  75, 205,  12,   6, 145,   0,  28,\r\n                         30, 148,   8,  45, 218, 171,  55, 249,  97, 166,  12,  35,   0,\r\n                         41, 221, 122, 215, 170,  31, 113, 186,  97, 119,  31,  23, 185,\r\n                         66, 140,  30,  41,  37,  63, 137, 109, 216,  55, 159, 145,  82,\r\n                         204, 86,  73, 222,  44, 198, 118, 240,  97]\r\n\r\n    assert shared_random == reference_random\r\n\r\n    hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())\r\n\r\n    assert 12731374 == hunna_kay_random_sum\r\n"
  },
  {
    "path": "modules/textual_inversion/learn_schedule.py",
    "content": "import tqdm\r\n\r\n\r\nclass LearnScheduleIterator:\r\n    def __init__(self, learn_rate, max_steps, cur_step=0):\r\n        \"\"\"\r\n        specify learn_rate as \"0.001:100, 0.00001:1000, 1e-5:10000\" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000\r\n        \"\"\"\r\n\r\n        pairs = learn_rate.split(',')\r\n        self.rates = []\r\n        self.it = 0\r\n        self.maxit = 0\r\n        try:\r\n            for pair in pairs:\r\n                if not pair.strip():\r\n                    continue\r\n                tmp = pair.split(':')\r\n                if len(tmp) == 2:\r\n                    step = int(tmp[1])\r\n                    if step > cur_step:\r\n                        self.rates.append((float(tmp[0]), min(step, max_steps)))\r\n                        self.maxit += 1\r\n                        if step > max_steps:\r\n                            return\r\n                    elif step == -1:\r\n                        self.rates.append((float(tmp[0]), max_steps))\r\n                        self.maxit += 1\r\n                        return\r\n                else:\r\n                    self.rates.append((float(tmp[0]), max_steps))\r\n                    self.maxit += 1\r\n                    return\r\n            assert self.rates\r\n        except (ValueError, AssertionError) as e:\r\n            raise Exception('Invalid learning rate schedule. 
It should be a number or, for example, like \"0.001:100, 0.00001:1000, 1e-5:10000\" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e\r\n\r\n\r\n    def __iter__(self):\r\n        return self\r\n\r\n    def __next__(self):\r\n        if self.it < self.maxit:\r\n            self.it += 1\r\n            return self.rates[self.it - 1]\r\n        else:\r\n            raise StopIteration\r\n\r\n\r\nclass LearnRateScheduler:\r\n    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):\r\n        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)\r\n        (self.learn_rate,  self.end_step) = next(self.schedules)\r\n        self.verbose = verbose\r\n\r\n        if self.verbose:\r\n            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')\r\n\r\n        self.finished = False\r\n\r\n    def step(self, step_number):\r\n        if step_number < self.end_step:\r\n            return False\r\n\r\n        try:\r\n            (self.learn_rate, self.end_step) = next(self.schedules)\r\n        except StopIteration:\r\n            self.finished = True\r\n            return False\r\n        return True\r\n\r\n    def apply(self, optimizer, step_number):\r\n        if not self.step(step_number):\r\n            return\r\n\r\n        if self.verbose:\r\n            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')\r\n\r\n        for pg in optimizer.param_groups:\r\n            pg['lr'] = self.learn_rate\r\n\r\n"
  },
  {
    "path": "modules/textual_inversion/saving_settings.py",
    "content": "import datetime\r\nimport json\r\nimport os\r\n\r\nsaved_params_shared = {\r\n    \"batch_size\",\r\n    \"clip_grad_mode\",\r\n    \"clip_grad_value\",\r\n    \"create_image_every\",\r\n    \"data_root\",\r\n    \"gradient_step\",\r\n    \"initial_step\",\r\n    \"latent_sampling_method\",\r\n    \"learn_rate\",\r\n    \"log_directory\",\r\n    \"model_hash\",\r\n    \"model_name\",\r\n    \"num_of_dataset_images\",\r\n    \"steps\",\r\n    \"template_file\",\r\n    \"training_height\",\r\n    \"training_width\",\r\n}\r\nsaved_params_ti = {\r\n    \"embedding_name\",\r\n    \"num_vectors_per_token\",\r\n    \"save_embedding_every\",\r\n    \"save_image_with_stored_embedding\",\r\n}\r\nsaved_params_hypernet = {\r\n    \"activation_func\",\r\n    \"add_layer_norm\",\r\n    \"hypernetwork_name\",\r\n    \"layer_structure\",\r\n    \"save_hypernetwork_every\",\r\n    \"use_dropout\",\r\n    \"weight_init\",\r\n}\r\nsaved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet\r\nsaved_params_previews = {\r\n    \"preview_cfg_scale\",\r\n    \"preview_height\",\r\n    \"preview_negative_prompt\",\r\n    \"preview_prompt\",\r\n    \"preview_sampler_index\",\r\n    \"preview_seed\",\r\n    \"preview_steps\",\r\n    \"preview_width\",\r\n}\r\n\r\n\r\ndef save_settings_to_file(log_directory, all_params):\r\n    now = datetime.datetime.now()\r\n    params = {\"datetime\": now.strftime(\"%Y-%m-%d %H:%M:%S\")}\r\n\r\n    keys = saved_params_all\r\n    if all_params.get('preview_from_txt2img'):\r\n        keys = keys | saved_params_previews\r\n\r\n    params.update({k: v for k, v in all_params.items() if k in keys})\r\n\r\n    filename = f'settings-{now.strftime(\"%Y-%m-%d-%H-%M-%S\")}.json'\r\n    with open(os.path.join(log_directory, filename), \"w\") as file:\r\n        json.dump(params, file, indent=4)\r\n"
  },
  {
    "path": "modules/textual_inversion/textual_inversion.py",
    "content": "import os\r\nfrom collections import namedtuple\r\nfrom contextlib import closing\r\n\r\nimport torch\r\nimport tqdm\r\nimport html\r\nimport datetime\r\nimport csv\r\nimport safetensors.torch\r\n\r\nimport numpy as np\r\nfrom PIL import Image, PngImagePlugin\r\n\r\nfrom modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes\r\nimport modules.textual_inversion.dataset\r\nfrom modules.textual_inversion.learn_schedule import LearnRateScheduler\r\n\r\nfrom modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay\r\nfrom modules.textual_inversion.saving_settings import save_settings_to_file\r\n\r\n\r\nTextualInversionTemplate = namedtuple(\"TextualInversionTemplate\", [\"name\", \"path\"])\r\ntextual_inversion_templates = {}\r\n\r\n\r\ndef list_textual_inversion_templates():\r\n    textual_inversion_templates.clear()\r\n\r\n    for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):\r\n        for fn in fns:\r\n            path = os.path.join(root, fn)\r\n\r\n            textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)\r\n\r\n    return textual_inversion_templates\r\n\r\n\r\nclass Embedding:\r\n    def __init__(self, vec, name, step=None):\r\n        self.vec = vec\r\n        self.name = name\r\n        self.step = step\r\n        self.shape = None\r\n        self.vectors = 0\r\n        self.cached_checksum = None\r\n        self.sd_checkpoint = None\r\n        self.sd_checkpoint_name = None\r\n        self.optimizer_state_dict = None\r\n        self.filename = None\r\n        self.hash = None\r\n        self.shorthash = None\r\n\r\n    def save(self, filename):\r\n        embedding_data = {\r\n            \"string_to_token\": {\"*\": 265},\r\n            \"string_to_param\": {\"*\": self.vec},\r\n            \"name\": self.name,\r\n            \"step\": 
self.step,\r\n            \"sd_checkpoint\": self.sd_checkpoint,\r\n            \"sd_checkpoint_name\": self.sd_checkpoint_name,\r\n        }\r\n\r\n        torch.save(embedding_data, filename)\r\n\r\n        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:\r\n            optimizer_saved_dict = {\r\n                'hash': self.checksum(),\r\n                'optimizer_state_dict': self.optimizer_state_dict,\r\n            }\r\n            torch.save(optimizer_saved_dict, f\"{filename}.optim\")\r\n\r\n    def checksum(self):\r\n        if self.cached_checksum is not None:\r\n            return self.cached_checksum\r\n\r\n        def const_hash(a):\r\n            r = 0\r\n            for v in a:\r\n                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF\r\n            return r\r\n\r\n        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'\r\n        return self.cached_checksum\r\n\r\n    def set_hash(self, v):\r\n        self.hash = v\r\n        self.shorthash = self.hash[0:12]\r\n\r\n\r\nclass DirWithTextualInversionEmbeddings:\r\n    def __init__(self, path):\r\n        self.path = path\r\n        self.mtime = None\r\n\r\n    def has_changed(self):\r\n        if not os.path.isdir(self.path):\r\n            return False\r\n\r\n        mt = os.path.getmtime(self.path)\r\n        if self.mtime is None or mt > self.mtime:\r\n            return True\r\n\r\n    def update(self):\r\n        if not os.path.isdir(self.path):\r\n            return\r\n\r\n        self.mtime = os.path.getmtime(self.path)\r\n\r\n\r\nclass EmbeddingDatabase:\r\n    def __init__(self):\r\n        self.ids_lookup = {}\r\n        self.word_embeddings = {}\r\n        self.skipped_embeddings = {}\r\n        self.expected_shape = -1\r\n        self.embedding_dirs = {}\r\n        self.previously_displayed_embeddings = ()\r\n\r\n    def add_embedding_dir(self, path):\r\n        self.embedding_dirs[path] = 
DirWithTextualInversionEmbeddings(path)\r\n\r\n    def clear_embedding_dirs(self):\r\n        self.embedding_dirs.clear()\r\n\r\n    def register_embedding(self, embedding, model):\r\n        return self.register_embedding_by_name(embedding, model, embedding.name)\r\n\r\n    def register_embedding_by_name(self, embedding, model, name):\r\n        ids = model.cond_stage_model.tokenize([name])[0]\r\n        first_id = ids[0]\r\n        if first_id not in self.ids_lookup:\r\n            self.ids_lookup[first_id] = []\r\n        if name in self.word_embeddings:\r\n            # remove old one from the lookup list\r\n            lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]\r\n        else:\r\n            lookup = self.ids_lookup[first_id]\r\n        if embedding is not None:\r\n            lookup += [(ids, embedding)]\r\n        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)\r\n        if embedding is None:\r\n            # unregister embedding with specified name\r\n            if name in self.word_embeddings:\r\n                del self.word_embeddings[name]\r\n            if len(self.ids_lookup[first_id])==0:\r\n                del self.ids_lookup[first_id]\r\n            return None\r\n        self.word_embeddings[name] = embedding\r\n        return embedding\r\n\r\n    def get_expected_shape(self):\r\n        devices.torch_npu_set_device()\r\n        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(\",\", 1)\r\n        return vec.shape[1]\r\n\r\n    def load_from_file(self, path, filename):\r\n        name, ext = os.path.splitext(filename)\r\n        ext = ext.upper()\r\n\r\n        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:\r\n            _, second_ext = os.path.splitext(name)\r\n            if second_ext.upper() == '.PREVIEW':\r\n                return\r\n\r\n            embed_image = Image.open(path)\r\n            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in 
embed_image.text:\r\n                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])\r\n                name = data.get('name', name)\r\n            else:\r\n                data = extract_image_data_embed(embed_image)\r\n                if data:\r\n                    name = data.get('name', name)\r\n                else:\r\n                    # if data is None, means this is not an embedding, just a preview image\r\n                    return\r\n        elif ext in ['.BIN', '.PT']:\r\n            data = torch.load(path, map_location=\"cpu\")\r\n        elif ext in ['.SAFETENSORS']:\r\n            data = safetensors.torch.load_file(path, device=\"cpu\")\r\n        else:\r\n            return\r\n\r\n        if data is not None:\r\n            embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)\r\n\r\n            if self.expected_shape == -1 or self.expected_shape == embedding.shape:\r\n                self.register_embedding(embedding, shared.sd_model)\r\n            else:\r\n                self.skipped_embeddings[name] = embedding\r\n        else:\r\n            print(f\"Unable to load Textual inversion embedding due to data issue: '{name}'.\")\r\n\r\n\r\n    def load_from_dir(self, embdir):\r\n        if not os.path.isdir(embdir.path):\r\n            return\r\n\r\n        for root, _, fns in os.walk(embdir.path, followlinks=True):\r\n            for fn in fns:\r\n                try:\r\n                    fullfn = os.path.join(root, fn)\r\n\r\n                    if os.stat(fullfn).st_size == 0:\r\n                        continue\r\n\r\n                    self.load_from_file(fullfn, fn)\r\n                except Exception:\r\n                    errors.report(f\"Error loading embedding {fn}\", exc_info=True)\r\n                    continue\r\n\r\n    def load_textual_inversion_embeddings(self, force_reload=False):\r\n        if not force_reload:\r\n            need_reload = False\r\n            for embdir in 
self.embedding_dirs.values():\r\n                if embdir.has_changed():\r\n                    need_reload = True\r\n                    break\r\n\r\n            if not need_reload:\r\n                return\r\n\r\n        self.ids_lookup.clear()\r\n        self.word_embeddings.clear()\r\n        self.skipped_embeddings.clear()\r\n        self.expected_shape = self.get_expected_shape()\r\n\r\n        for embdir in self.embedding_dirs.values():\r\n            self.load_from_dir(embdir)\r\n            embdir.update()\r\n\r\n        # re-sort word_embeddings because load_from_dir may not load in alphabetic order.\r\n        # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.\r\n        sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}\r\n        self.word_embeddings.clear()\r\n        self.word_embeddings.update(sorted_word_embeddings)\r\n\r\n        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))\r\n        if shared.opts.textual_inversion_print_at_load and self.previously_displayed_embeddings != displayed_embeddings:\r\n            self.previously_displayed_embeddings = displayed_embeddings\r\n            print(f\"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}\")\r\n            if self.skipped_embeddings:\r\n                print(f\"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}\")\r\n\r\n    def find_embedding_at_position(self, tokens, offset):\r\n        token = tokens[offset]\r\n        possible_matches = self.ids_lookup.get(token, None)\r\n\r\n        if possible_matches is None:\r\n            return None, None\r\n\r\n        for ids, embedding in possible_matches:\r\n            if tokens[offset:offset + len(ids)] == ids:\r\n                return 
embedding, len(ids)\r\n\r\n        return None, None\r\n\r\n\r\ndef create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):\r\n    cond_model = shared.sd_model.cond_stage_model\r\n\r\n    with devices.autocast():\r\n        cond_model([\"\"])  # will send cond model to GPU if lowvram/medvram is active\r\n\r\n    #cond_model expects at least some text, so we provide '*' as backup.\r\n    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)\r\n    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)\r\n\r\n    #Only copy if we provided an init_text, otherwise keep vectors as zeros\r\n    if init_text:\r\n        for i in range(num_vectors_per_token):\r\n            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]\r\n\r\n    # Remove illegal characters from name.\r\n    name = \"\".join( x for x in name if (x.isalnum() or x in \"._- \"))\r\n    fn = os.path.join(shared.cmd_opts.embeddings_dir, f\"{name}.pt\")\r\n    if not overwrite_old:\r\n        assert not os.path.exists(fn), f\"file {fn} already exists\"\r\n\r\n    embedding = Embedding(vec, name)\r\n    embedding.step = 0\r\n    embedding.save(fn)\r\n\r\n    return fn\r\n\r\n\r\ndef create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):\r\n    if 'string_to_param' in data:  # textual inversion embeddings\r\n        param_dict = data['string_to_param']\r\n        param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11\r\n        assert len(param_dict) == 1, 'embedding file has multiple terms in it'\r\n        emb = next(iter(param_dict.items()))[1]\r\n        vec = emb.detach().to(devices.device, dtype=torch.float32)\r\n        shape = vec.shape[-1]\r\n        vectors = vec.shape[0]\r\n    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding\r\n        vec = {k: 
v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}\r\n        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]\r\n        vectors = data['clip_g'].shape[0]\r\n    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:  # diffuser concepts\r\n        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'\r\n\r\n        emb = next(iter(data.values()))\r\n        if len(emb.shape) == 1:\r\n            emb = emb.unsqueeze(0)\r\n        vec = emb.detach().to(devices.device, dtype=torch.float32)\r\n        shape = vec.shape[-1]\r\n        vectors = vec.shape[0]\r\n    else:\r\n        raise Exception(f\"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.\")\r\n\r\n    embedding = Embedding(vec, name)\r\n    embedding.step = data.get('step', None)\r\n    embedding.sd_checkpoint = data.get('sd_checkpoint', None)\r\n    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)\r\n    embedding.vectors = vectors\r\n    embedding.shape = shape\r\n\r\n    if filepath:\r\n        embedding.filename = filepath\r\n        embedding.set_hash(hashes.sha256(filepath, \"textual_inversion/\" + name) or '')\r\n\r\n    return embedding\r\n\r\n\r\ndef write_loss(log_directory, filename, step, epoch_len, values):\r\n    if shared.opts.training_write_csv_every == 0:\r\n        return\r\n\r\n    if step % shared.opts.training_write_csv_every != 0:\r\n        return\r\n    write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True\r\n\r\n    with open(os.path.join(log_directory, filename), \"a+\", newline='') as fout:\r\n        csv_writer = csv.DictWriter(fout, fieldnames=[\"step\", \"epoch\", \"epoch_step\", *(values.keys())])\r\n\r\n        if write_csv_header:\r\n            csv_writer.writeheader()\r\n\r\n        epoch = (step - 1) // epoch_len\r\n        epoch_step = (step - 1) % epoch_len\r\n\r\n        csv_writer.writerow({\r\n   
         \"step\": step,\r\n            \"epoch\": epoch,\r\n            \"epoch_step\": epoch_step,\r\n            **values,\r\n        })\r\n\r\ndef tensorboard_setup(log_directory):\r\n    from torch.utils.tensorboard import SummaryWriter\r\n    os.makedirs(os.path.join(log_directory, \"tensorboard\"), exist_ok=True)\r\n    return SummaryWriter(\r\n            log_dir=os.path.join(log_directory, \"tensorboard\"),\r\n            flush_secs=shared.opts.training_tensorboard_flush_every)\r\n\r\ndef tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):\r\n    tensorboard_add_scaler(tensorboard_writer, \"Loss/train\", loss, global_step)\r\n    tensorboard_add_scaler(tensorboard_writer, f\"Loss/train/epoch-{epoch_num}\", loss, step)\r\n    tensorboard_add_scaler(tensorboard_writer, \"Learn rate/train\", learn_rate, global_step)\r\n    tensorboard_add_scaler(tensorboard_writer, f\"Learn rate/train/epoch-{epoch_num}\", learn_rate, step)\r\n\r\ndef tensorboard_add_scaler(tensorboard_writer, tag, value, step):\r\n    tensorboard_writer.add_scalar(tag=tag,\r\n        scalar_value=value, global_step=step)\r\n\r\ndef tensorboard_add_image(tensorboard_writer, tag, pil_image, step):\r\n    # Convert a pil image to a torch tensor\r\n    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))\r\n    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],\r\n        len(pil_image.getbands()))\r\n    img_tensor = img_tensor.permute((2, 0, 1))\r\n\r\n    tensorboard_writer.add_image(tag, img_tensor, global_step=step)\r\n\r\ndef validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name=\"embedding\"):\r\n    assert model_name, f\"{name} not selected\"\r\n    assert learn_rate, \"Learning rate is empty or 0\"\r\n    assert isinstance(batch_size, int), \"Batch size must be integer\"\r\n    assert batch_size > 0, 
\"Batch size must be positive\"\r\n    assert isinstance(gradient_step, int), \"Gradient accumulation step must be integer\"\r\n    assert gradient_step > 0, \"Gradient accumulation step must be positive\"\r\n    assert data_root, \"Dataset directory is empty\"\r\n    assert os.path.isdir(data_root), \"Dataset directory doesn't exist\"\r\n    assert os.listdir(data_root), \"Dataset directory is empty\"\r\n    assert template_filename, \"Prompt template file not selected\"\r\n    assert template_file, f\"Prompt template file {template_filename} not found\"\r\n    assert os.path.isfile(template_file.path), f\"Prompt template file {template_filename} doesn't exist\"\r\n    assert steps, \"Max steps is empty or 0\"\r\n    assert isinstance(steps, int), \"Max steps must be integer\"\r\n    assert steps > 0, \"Max steps must be positive\"\r\n    assert isinstance(save_model_every, int), \"Save {name} must be integer\"\r\n    assert save_model_every >= 0, \"Save {name} must be positive or 0\"\r\n    assert isinstance(create_image_every, int), \"Create image must be integer\"\r\n    assert create_image_every >= 0, \"Create image must be positive or 0\"\r\n    if save_model_every or create_image_every:\r\n        assert log_directory, \"Log directory is empty\"\r\n\r\n\r\ndef train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_name, preview_cfg_scale, preview_seed, preview_width, preview_height):\r\n    from modules import processing\r\n\r\n    save_embedding_every = save_embedding_every or 0\r\n    create_image_every = create_image_every or 0\r\n    template_file = 
textual_inversion_templates.get(template_filename, None)\r\n    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name=\"embedding\")\r\n    template_file = template_file.path\r\n\r\n    shared.state.job = \"train-embedding\"\r\n    shared.state.textinfo = \"Initializing textual inversion training...\"\r\n    shared.state.job_count = steps\r\n\r\n    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')\r\n\r\n    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime(\"%Y-%m-%d\"), embedding_name)\r\n    unload = shared.opts.unload_models_when_training\r\n\r\n    if save_embedding_every > 0:\r\n        embedding_dir = os.path.join(log_directory, \"embeddings\")\r\n        os.makedirs(embedding_dir, exist_ok=True)\r\n    else:\r\n        embedding_dir = None\r\n\r\n    if create_image_every > 0:\r\n        images_dir = os.path.join(log_directory, \"images\")\r\n        os.makedirs(images_dir, exist_ok=True)\r\n    else:\r\n        images_dir = None\r\n\r\n    if create_image_every > 0 and save_image_with_stored_embedding:\r\n        images_embeds_dir = os.path.join(log_directory, \"image_embeddings\")\r\n        os.makedirs(images_embeds_dir, exist_ok=True)\r\n    else:\r\n        images_embeds_dir = None\r\n\r\n    hijack = sd_hijack.model_hijack\r\n\r\n    embedding = hijack.embedding_db.word_embeddings[embedding_name]\r\n    checkpoint = sd_models.select_checkpoint()\r\n\r\n    initial_step = embedding.step or 0\r\n    if initial_step >= steps:\r\n        shared.state.textinfo = \"Model has already been trained beyond specified max steps\"\r\n        return embedding, filename\r\n\r\n    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)\r\n    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == \"value\" else \\\r\n        torch.nn.utils.clip_grad_norm_ 
if clip_grad_mode == \"norm\" else \\\r\n        None\r\n    if clip_grad:\r\n        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)\r\n    # dataset loading may take a while, so input validations and early returns should be done before this\r\n    shared.state.textinfo = f\"Preparing dataset from {html.escape(data_root)}...\"\r\n    old_parallel_processing_allowed = shared.parallel_processing_allowed\r\n\r\n    tensorboard_writer = None\r\n    if shared.opts.training_enable_tensorboard:\r\n        try:\r\n            tensorboard_writer = tensorboard_setup(log_directory)\r\n        except ImportError:\r\n            errors.report(\"Error initializing tensorboard\", exc_info=True)\r\n\r\n    pin_memory = shared.opts.pin_memory\r\n\r\n    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)\r\n\r\n    if shared.opts.save_training_settings_to_txt:\r\n        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})\r\n\r\n    latent_sampling_method = ds.latent_sampling_method\r\n\r\n    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)\r\n\r\n    if unload:\r\n        shared.parallel_processing_allowed = False\r\n        shared.sd_model.first_stage_model.to(devices.cpu)\r\n\r\n    embedding.vec.requires_grad = 
True\r\n    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)\r\n    if shared.opts.save_optimizer_state:\r\n        optimizer_state_dict = None\r\n        if os.path.exists(f\"{filename}.optim\"):\r\n            optimizer_saved_dict = torch.load(f\"{filename}.optim\", map_location='cpu')\r\n            if embedding.checksum() == optimizer_saved_dict.get('hash', None):\r\n                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)\r\n\r\n        if optimizer_state_dict is not None:\r\n            optimizer.load_state_dict(optimizer_state_dict)\r\n            print(\"Loaded existing optimizer from checkpoint\")\r\n        else:\r\n            print(\"No saved optimizer exists in checkpoint\")\r\n\r\n    scaler = torch.cuda.amp.GradScaler()\r\n\r\n    batch_size = ds.batch_size\r\n    gradient_step = ds.gradient_step\r\n    # n steps = batch_size * gradient_step * n image processed\r\n    steps_per_epoch = len(ds) // batch_size // gradient_step\r\n    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step\r\n    loss_step = 0\r\n    _loss_step = 0 #internal\r\n\r\n    last_saved_file = \"<none>\"\r\n    last_saved_image = \"<none>\"\r\n    forced_filename = \"<none>\"\r\n    embedding_yet_to_be_embedded = False\r\n\r\n    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}\r\n    img_c = None\r\n\r\n    pbar = tqdm.tqdm(total=steps - initial_step)\r\n    try:\r\n        sd_hijack_checkpoint.add()\r\n\r\n        for _ in range((steps-initial_step) * gradient_step):\r\n            if scheduler.finished:\r\n                break\r\n            if shared.state.interrupted:\r\n                break\r\n            for j, batch in enumerate(dl):\r\n                # works as a drop_last=True for gradient accumulation\r\n                if j == max_steps_per_epoch:\r\n                    break\r\n                
scheduler.apply(optimizer, embedding.step)\r\n                if scheduler.finished:\r\n                    break\r\n                if shared.state.interrupted:\r\n                    break\r\n\r\n                if clip_grad:\r\n                    clip_grad_sched.step(embedding.step)\r\n\r\n                with devices.autocast():\r\n                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)\r\n                    if use_weight:\r\n                        w = batch.weight.to(devices.device, non_blocking=pin_memory)\r\n                    c = shared.sd_model.cond_stage_model(batch.cond_text)\r\n\r\n                    if is_training_inpainting_model:\r\n                        if img_c is None:\r\n                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)\r\n\r\n                        cond = {\"c_concat\": [img_c], \"c_crossattn\": [c]}\r\n                    else:\r\n                        cond = c\r\n\r\n                    if use_weight:\r\n                        loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step\r\n                        del w\r\n                    else:\r\n                        loss = shared.sd_model.forward(x, cond)[0] / gradient_step\r\n                    del x\r\n\r\n                    _loss_step += loss.item()\r\n                scaler.scale(loss).backward()\r\n\r\n                # go back until we reach gradient accumulation steps\r\n                if (j + 1) % gradient_step != 0:\r\n                    continue\r\n\r\n                if clip_grad:\r\n                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)\r\n\r\n                scaler.step(optimizer)\r\n                scaler.update()\r\n                embedding.step += 1\r\n                pbar.update()\r\n                optimizer.zero_grad(set_to_none=True)\r\n                loss_step = _loss_step\r\n                _loss_step = 0\r\n\r\n 
               steps_done = embedding.step + 1\r\n\r\n                epoch_num = embedding.step // steps_per_epoch\r\n                epoch_step = embedding.step % steps_per_epoch\r\n\r\n                description = f\"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}\"\r\n                pbar.set_description(description)\r\n                if embedding_dir is not None and steps_done % save_embedding_every == 0:\r\n                    # Before saving, change name to match current checkpoint.\r\n                    embedding_name_every = f'{embedding_name}-{steps_done}'\r\n                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')\r\n                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)\r\n                    embedding_yet_to_be_embedded = True\r\n\r\n                write_loss(log_directory, \"textual_inversion_loss.csv\", embedding.step, steps_per_epoch, {\r\n                    \"loss\": f\"{loss_step:.7f}\",\r\n                    \"learn_rate\": scheduler.learn_rate\r\n                })\r\n\r\n                if images_dir is not None and steps_done % create_image_every == 0:\r\n                    forced_filename = f'{embedding_name}-{steps_done}'\r\n                    last_saved_image = os.path.join(images_dir, forced_filename)\r\n\r\n                    shared.sd_model.first_stage_model.to(devices.device)\r\n\r\n                    p = processing.StableDiffusionProcessingTxt2Img(\r\n                        sd_model=shared.sd_model,\r\n                        do_not_save_grid=True,\r\n                        do_not_save_samples=True,\r\n                        do_not_reload_embeddings=True,\r\n                    )\r\n\r\n                    if preview_from_txt2img:\r\n                        p.prompt = preview_prompt\r\n                        p.negative_prompt = 
preview_negative_prompt\r\n                        p.steps = preview_steps\r\n                        p.sampler_name = sd_samplers.samplers_map[preview_sampler_name.lower()]\r\n                        p.cfg_scale = preview_cfg_scale\r\n                        p.seed = preview_seed\r\n                        p.width = preview_width\r\n                        p.height = preview_height\r\n                    else:\r\n                        p.prompt = batch.cond_text[0]\r\n                        p.steps = 20\r\n                        p.width = training_width\r\n                        p.height = training_height\r\n\r\n                    preview_text = p.prompt\r\n\r\n                    with closing(p):\r\n                        processed = processing.process_images(p)\r\n                        image = processed.images[0] if len(processed.images) > 0 else None\r\n\r\n                    if unload:\r\n                        shared.sd_model.first_stage_model.to(devices.cpu)\r\n\r\n                    if image is not None:\r\n                        shared.state.assign_current_image(image)\r\n\r\n                        last_saved_image, last_text_info = images.save_image(image, images_dir, \"\", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)\r\n                        last_saved_image += f\", prompt: {preview_text}\"\r\n\r\n                        if tensorboard_writer and shared.opts.training_tensorboard_save_images:\r\n                            tensorboard_add_image(tensorboard_writer, f\"Validation at epoch {epoch_num}\", image, embedding.step)\r\n\r\n                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:\r\n\r\n                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')\r\n\r\n                        info = PngImagePlugin.PngInfo()\r\n                
        data = torch.load(last_saved_file)\r\n                        info.add_text(\"sd-ti-embedding\", embedding_to_b64(data))\r\n\r\n                        title = f\"<{data.get('name', '???')}>\"\r\n\r\n                        try:\r\n                            vectorSize = list(data['string_to_param'].values())[0].shape[0]\r\n                        except Exception:\r\n                            vectorSize = '?'\r\n\r\n                        checkpoint = sd_models.select_checkpoint()\r\n                        footer_left = checkpoint.model_name\r\n                        footer_mid = f'[{checkpoint.shorthash}]'\r\n                        footer_right = f'{vectorSize}v {steps_done}s'\r\n\r\n                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)\r\n                        captioned_image = insert_image_data_embed(captioned_image, data)\r\n\r\n                        captioned_image.save(last_saved_image_chunks, \"PNG\", pnginfo=info)\r\n                        embedding_yet_to_be_embedded = False\r\n\r\n                    last_saved_image, last_text_info = images.save_image(image, images_dir, \"\", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)\r\n                    last_saved_image += f\", prompt: {preview_text}\"\r\n\r\n                shared.state.job_no = embedding.step\r\n\r\n                shared.state.textinfo = f\"\"\"\r\n<p>\r\nLoss: {loss_step:.7f}<br/>\r\nStep: {steps_done}<br/>\r\nLast prompt: {html.escape(batch.cond_text[0])}<br/>\r\nLast saved embedding: {html.escape(last_saved_file)}<br/>\r\nLast saved image: {html.escape(last_saved_image)}<br/>\r\n</p>\r\n\"\"\"\r\n        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')\r\n        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)\r\n    except Exception:\r\n      
  errors.report(\"Error training embedding\", exc_info=True)\r\n    finally:\r\n        pbar.leave = False\r\n        pbar.close()\r\n        shared.sd_model.first_stage_model.to(devices.device)\r\n        shared.parallel_processing_allowed = old_parallel_processing_allowed\r\n        sd_hijack_checkpoint.remove()\r\n\r\n    return embedding, filename\r\n\r\n\r\ndef save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):\r\n    old_embedding_name = embedding.name\r\n    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, \"sd_checkpoint\") else None\r\n    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, \"sd_checkpoint_name\") else None\r\n    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, \"cached_checksum\") else None\r\n    try:\r\n        embedding.sd_checkpoint = checkpoint.shorthash\r\n        embedding.sd_checkpoint_name = checkpoint.model_name\r\n        if remove_cached_checksum:\r\n            embedding.cached_checksum = None\r\n        embedding.name = embedding_name\r\n        embedding.optimizer_state_dict = optimizer.state_dict()\r\n        embedding.save(filename)\r\n    except:\r\n        embedding.sd_checkpoint = old_sd_checkpoint\r\n        embedding.sd_checkpoint_name = old_sd_checkpoint_name\r\n        embedding.name = old_embedding_name\r\n        embedding.cached_checksum = old_cached_checksum\r\n        raise\r\n"
  },
  {
    "path": "modules/textual_inversion/ui.py",
    "content": "import html\r\n\r\nimport gradio as gr\r\n\r\nimport modules.textual_inversion.textual_inversion\r\nfrom modules import sd_hijack, shared\r\n\r\n\r\ndef create_embedding(name, initialization_text, nvpt, overwrite_old):\r\n    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)\r\n\r\n    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()\r\n\r\n    return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f\"Created: {filename}\", \"\"\r\n\r\n\r\ndef train_embedding(*args):\r\n\r\n    assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'\r\n\r\n    apply_optimizations = shared.opts.training_xattention_optimizations\r\n    try:\r\n        if not apply_optimizations:\r\n            sd_hijack.undo_optimizations()\r\n\r\n        embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)\r\n\r\n        res = f\"\"\"\r\nTraining {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.\r\nEmbedding saved to {html.escape(filename)}\r\n\"\"\"\r\n        return res, \"\"\r\n    except Exception:\r\n        raise\r\n    finally:\r\n        if not apply_optimizations:\r\n            sd_hijack.apply_optimizations()\r\n\r\n"
  },
  {
    "path": "modules/timer.py",
    "content": "import time\r\nimport argparse\r\n\r\n\r\nclass TimerSubcategory:\r\n    def __init__(self, timer, category):\r\n        self.timer = timer\r\n        self.category = category\r\n        self.start = None\r\n        self.original_base_category = timer.base_category\r\n\r\n    def __enter__(self):\r\n        self.start = time.time()\r\n        self.timer.base_category = self.original_base_category + self.category + \"/\"\r\n        self.timer.subcategory_level += 1\r\n\r\n        if self.timer.print_log:\r\n            print(f\"{'  ' * self.timer.subcategory_level}{self.category}:\")\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        elapsed_for_subcategroy = time.time() - self.start\r\n        self.timer.base_category = self.original_base_category\r\n        self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)\r\n        self.timer.subcategory_level -= 1\r\n        self.timer.record(self.category, disable_log=True)\r\n\r\n\r\nclass Timer:\r\n    def __init__(self, print_log=False):\r\n        self.start = time.time()\r\n        self.records = {}\r\n        self.total = 0\r\n        self.base_category = ''\r\n        self.print_log = print_log\r\n        self.subcategory_level = 0\r\n\r\n    def elapsed(self):\r\n        end = time.time()\r\n        res = end - self.start\r\n        self.start = end\r\n        return res\r\n\r\n    def add_time_to_record(self, category, amount):\r\n        if category not in self.records:\r\n            self.records[category] = 0\r\n\r\n        self.records[category] += amount\r\n\r\n    def record(self, category, extra_time=0, disable_log=False):\r\n        e = self.elapsed()\r\n\r\n        self.add_time_to_record(self.base_category + category, e + extra_time)\r\n\r\n        self.total += e + extra_time\r\n\r\n        if self.print_log and not disable_log:\r\n            print(f\"{'  ' * self.subcategory_level}{category}: done in {e + 
extra_time:.3f}s\")\r\n\r\n    def subcategory(self, name):\r\n        self.elapsed()\r\n\r\n        subcat = TimerSubcategory(self, name)\r\n        return subcat\r\n\r\n    def summary(self):\r\n        res = f\"{self.total:.1f}s\"\r\n\r\n        additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category]\r\n        if not additions:\r\n            return res\r\n\r\n        res += \" (\"\r\n        res += \", \".join([f\"{category}: {time_taken:.1f}s\" for category, time_taken in additions])\r\n        res += \")\"\r\n\r\n        return res\r\n\r\n    def dump(self):\r\n        return {'total': self.total, 'records': self.records}\r\n\r\n    def reset(self):\r\n        self.__init__()\r\n\r\n\r\nparser = argparse.ArgumentParser(add_help=False)\r\nparser.add_argument(\"--log-startup\", action='store_true', help=\"print a detailed log of what's happening at startup\")\r\nargs = parser.parse_known_args()[0]\r\n\r\nstartup_timer = Timer(print_log=args.log_startup)\r\n\r\nstartup_record = None\r\n"
  },
  {
    "path": "modules/torch_utils.py",
    "content": "from __future__ import annotations\n\nimport torch.nn\nimport torch\n\n\ndef get_param(model) -> torch.nn.Parameter:\n    \"\"\"\n    Find the first parameter in a model or module.\n    \"\"\"\n    if hasattr(model, \"model\") and hasattr(model.model, \"parameters\"):\n        # Unpeel a model descriptor to get at the actual Torch module.\n        model = model.model\n\n    for param in model.parameters():\n        return param\n\n    raise ValueError(f\"No parameters found in model {model!r}\")\n\n\ndef float64(t: torch.Tensor):\n    \"\"\"return torch.float64 if device is not mps or xpu, else return torch.float32\"\"\"\n    if t.device.type in ['mps', 'xpu']:\n        return torch.float32\n    return torch.float64\n"
  },
  {
    "path": "modules/txt2img.py",
    "content": "import json\r\nfrom contextlib import closing\r\n\r\nimport modules.scripts\r\nfrom modules import processing, infotext_utils\r\nfrom modules.infotext_utils import create_override_settings_dict, parse_generation_parameters\r\nfrom modules.shared import opts\r\nimport modules.shared as shared\r\nfrom modules.ui import plaintext_to_html\r\nfrom PIL import Image\r\nimport gradio as gr\r\n\r\n\r\ndef txt2img_create_processing(id_task: str, request: gr.Request, prompt: str, negative_prompt: str, prompt_styles, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_scheduler: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args, force_enable_hr=False):\r\n    override_settings = create_override_settings_dict(override_settings_texts)\r\n\r\n    if force_enable_hr:\r\n        enable_hr = True\r\n\r\n    p = processing.StableDiffusionProcessingTxt2Img(\r\n        sd_model=shared.sd_model,\r\n        outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,\r\n        outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,\r\n        prompt=prompt,\r\n        styles=prompt_styles,\r\n        negative_prompt=negative_prompt,\r\n        batch_size=batch_size,\r\n        n_iter=n_iter,\r\n        cfg_scale=cfg_scale,\r\n        width=width,\r\n        height=height,\r\n        enable_hr=enable_hr,\r\n        denoising_strength=denoising_strength,\r\n        hr_scale=hr_scale,\r\n        hr_upscaler=hr_upscaler,\r\n        hr_second_pass_steps=hr_second_pass_steps,\r\n        hr_resize_x=hr_resize_x,\r\n        hr_resize_y=hr_resize_y,\r\n        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,\r\n        hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else 
hr_sampler_name,\r\n        hr_scheduler=None if hr_scheduler == 'Use same scheduler' else hr_scheduler,\r\n        hr_prompt=hr_prompt,\r\n        hr_negative_prompt=hr_negative_prompt,\r\n        override_settings=override_settings,\r\n    )\r\n\r\n    p.scripts = modules.scripts.scripts_txt2img\r\n    p.script_args = args\r\n\r\n    p.user = request.username\r\n\r\n    if shared.opts.enable_console_prompts:\r\n        print(f\"\\ntxt2img: {prompt}\", file=shared.progress_print_out)\r\n\r\n    return p\r\n\r\n\r\ndef txt2img_upscale(id_task: str, request: gr.Request, gallery, gallery_index, generation_info, *args):\r\n    assert len(gallery) > 0, 'No image to upscale'\r\n    assert 0 <= gallery_index < len(gallery), f'Bad image index: {gallery_index}'\r\n\r\n    p = txt2img_create_processing(id_task, request, *args, force_enable_hr=True)\r\n    p.batch_size = 1\r\n    p.n_iter = 1\r\n    # txt2img_upscale attribute that signifies this is called by txt2img_upscale\r\n    p.txt2img_upscale = True\r\n\r\n    geninfo = json.loads(generation_info)\r\n\r\n    image_info = gallery[gallery_index] if 0 <= gallery_index < len(gallery) else gallery[0]\r\n    p.firstpass_image = infotext_utils.image_from_url_text(image_info)\r\n\r\n    parameters = parse_generation_parameters(geninfo.get('infotexts')[gallery_index], [])\r\n    p.seed = parameters.get('Seed', -1)\r\n    p.subseed = parameters.get('Variation seed', -1)\r\n\r\n    p.override_settings['save_images_before_highres_fix'] = False\r\n\r\n    with closing(p):\r\n        processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)\r\n\r\n        if processed is None:\r\n            processed = processing.process_images(p)\r\n\r\n    shared.total_tqdm.clear()\r\n\r\n    new_gallery = []\r\n    for i, image in enumerate(gallery):\r\n        if i == gallery_index:\r\n            geninfo[\"infotexts\"][gallery_index: gallery_index+1] = processed.infotexts\r\n            new_gallery.extend(processed.images)\r\n        
else:\r\n            fake_image = Image.new(mode=\"RGB\", size=(1, 1))\r\n            fake_image.already_saved_as = image[\"name\"].rsplit('?', 1)[0]\r\n            new_gallery.append(fake_image)\r\n\r\n    geninfo[\"infotexts\"][gallery_index] = processed.info\r\n\r\n    return new_gallery, json.dumps(geninfo), plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname=\"comments\")\r\n\r\n\r\ndef txt2img(id_task: str, request: gr.Request, *args):\r\n    p = txt2img_create_processing(id_task, request, *args)\r\n\r\n    with closing(p):\r\n        processed = modules.scripts.scripts_txt2img.run(p, *p.script_args)\r\n\r\n        if processed is None:\r\n            processed = processing.process_images(p)\r\n\r\n    shared.total_tqdm.clear()\r\n\r\n    generation_info_js = processed.js()\r\n    if opts.samples_log_stdout:\r\n        print(generation_info_js)\r\n\r\n    if opts.do_not_show_images:\r\n        processed.images = []\r\n\r\n    return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments, classname=\"comments\")\r\n"
  },
  {
    "path": "modules/ui.py",
    "content": "import datetime\r\nimport mimetypes\r\nimport os\r\nimport sys\r\nfrom functools import reduce\r\nimport warnings\r\nfrom contextlib import ExitStack\r\n\r\nimport gradio as gr\r\nimport gradio.utils\r\nimport numpy as np\r\nfrom PIL import Image, PngImagePlugin  # noqa: F401\r\nfrom modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call, wrap_gradio_call_no_job # noqa: F401\r\n\r\nfrom modules import gradio_extensons, sd_schedulers  # noqa: F401\r\nfrom modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, scripts, sd_samplers, processing, ui_extra_networks, ui_toprow, launch_utils\r\nfrom modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow\r\nfrom modules.paths import script_path\r\nfrom modules.ui_common import create_refresh_button\r\nfrom modules.ui_gradio_extensions import reload_javascript\r\n\r\nfrom modules.shared import opts, cmd_opts\r\n\r\nimport modules.infotext_utils as parameters_copypaste\r\nimport modules.hypernetworks.ui as hypernetworks_ui\r\nimport modules.textual_inversion.ui as textual_inversion_ui\r\nimport modules.textual_inversion.textual_inversion as textual_inversion\r\nimport modules.shared as shared\r\nfrom modules import prompt_parser\r\nfrom modules.sd_hijack import model_hijack\r\nfrom modules.infotext_utils import image_from_url_text, PasteField\r\n\r\ncreate_setting_component = ui_settings.create_setting_component\r\n\r\nwarnings.filterwarnings(\"default\" if opts.show_warnings else \"ignore\", category=UserWarning)\r\nwarnings.filterwarnings(\"default\" if opts.show_gradio_deprecation_warnings else \"ignore\", category=gr.deprecation.GradioDeprecationWarning)\r\n\r\n# this is a fix for Windows users. 
Without it, javascript files will be served with text/html content-type and the browser will not show any UI\r\nmimetypes.init()\r\nmimetypes.add_type('application/javascript', '.js')\r\nmimetypes.add_type('application/javascript', '.mjs')\r\n\r\n# Likewise, add explicit content-type header for certain missing image types\r\nmimetypes.add_type('image/webp', '.webp')\r\nmimetypes.add_type('image/avif', '.avif')\r\n\r\nif not cmd_opts.share and not cmd_opts.listen:\r\n    # fix gradio phoning home\r\n    gradio.utils.version_check = lambda: None\r\n    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'\r\n\r\nif cmd_opts.ngrok is not None:\r\n    import modules.ngrok as ngrok\r\n    print('ngrok authtoken detected, trying to connect...')\r\n    ngrok.connect(\r\n        cmd_opts.ngrok,\r\n        cmd_opts.port if cmd_opts.port is not None else 7860,\r\n        cmd_opts.ngrok_options\r\n        )\r\n\r\n\r\ndef gr_show(visible=True):\r\n    return {\"visible\": visible, \"__type__\": \"update\"}\r\n\r\n\r\nsample_img2img = \"assets/stable-samples/img2img/sketch-mountains-input.jpg\"\r\nsample_img2img = sample_img2img if os.path.exists(sample_img2img) else None\r\n\r\n# Using constants for these since the variation selector isn't visible.\r\n# Important that they exactly match script.js for tooltip to work.\r\nrandom_symbol = '\\U0001f3b2\\ufe0f'  # 🎲️\r\nreuse_symbol = '\\u267b\\ufe0f'  # ♻️\r\npaste_symbol = '\\u2199\\ufe0f'  # ↙\r\nrefresh_symbol = '\\U0001f504'  # 🔄\r\nsave_style_symbol = '\\U0001f4be'  # 💾\r\napply_style_symbol = '\\U0001f4cb'  # 📋\r\nclear_prompt_symbol = '\\U0001f5d1\\ufe0f'  # 🗑️\r\nextra_networks_symbol = '\\U0001F3B4'  # 🎴\r\nswitch_values_symbol = '\\U000021C5' # ⇅\r\nrestore_progress_symbol = '\\U0001F300' # 🌀\r\ndetect_image_size_symbol = '\\U0001F4D0'  # 📐\r\n\r\n\r\nplaintext_to_html = ui_common.plaintext_to_html\r\n\r\n\r\ndef send_gradio_gallery_to_image(x):\r\n    if len(x) == 0:\r\n        return None\r\n    return 
image_from_url_text(x[0])\r\n\r\n\r\ndef calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):\r\n    if not enable:\r\n        return \"\"\r\n\r\n    p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)\r\n    p.calculate_target_resolution()\r\n\r\n    return f\"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>\"\r\n\r\n\r\ndef resize_from_to_html(width, height, scale_by):\r\n    target_width = int(width * scale_by)\r\n    target_height = int(height * scale_by)\r\n\r\n    if not target_width or not target_height:\r\n        return \"no image selected\"\r\n\r\n    return f\"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>\"\r\n\r\n\r\ndef process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):\r\n    if mode in {0, 1, 3, 4}:\r\n        return [interrogation_function(ii_singles[mode]), None]\r\n    elif mode == 2:\r\n        return [interrogation_function(ii_singles[mode][\"image\"]), None]\r\n    elif mode == 5:\r\n        assert not shared.cmd_opts.hide_ui_dir_config, \"Launched with --hide-ui-dir-config, batch img2img disabled\"\r\n        images = shared.listfiles(ii_input_dir)\r\n        print(f\"Will process {len(images)} images.\")\r\n        if ii_output_dir != \"\":\r\n            os.makedirs(ii_output_dir, exist_ok=True)\r\n        else:\r\n            ii_output_dir = ii_input_dir\r\n\r\n        for image in images:\r\n            img = Image.open(image)\r\n            filename = os.path.basename(image)\r\n            left, _ = os.path.splitext(filename)\r\n            print(interrogation_function(img), file=open(os.path.join(ii_output_dir, f\"{left}.txt\"), 'a', encoding='utf-8'))\r\n\r\n      
  return [gr.update(), None]\r\n\r\n\r\ndef interrogate(image):\r\n    prompt = shared.interrogator.interrogate(image.convert(\"RGB\"))\r\n    return gr.update() if prompt is None else prompt\r\n\r\n\r\ndef interrogate_deepbooru(image):\r\n    prompt = deepbooru.model.tag(image)\r\n    return gr.update() if prompt is None else prompt\r\n\r\n\r\ndef connect_clear_prompt(button):\r\n    \"\"\"Given clear button, prompt, and token_counter objects, setup clear prompt button click event\"\"\"\r\n    button.click(\r\n        _js=\"clear_prompt\",\r\n        fn=None,\r\n        inputs=[],\r\n        outputs=[],\r\n    )\r\n\r\n\r\ndef update_token_counter(text, steps, styles, *, is_positive=True):\r\n    params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive)\r\n    script_callbacks.before_token_counter_callback(params)\r\n    text = params.prompt\r\n    steps = params.steps\r\n    styles = params.styles\r\n    is_positive = params.is_positive\r\n\r\n    if shared.opts.include_styles_into_token_counters:\r\n        apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt\r\n        text = apply_styles(text, styles)\r\n\r\n    try:\r\n        text, _ = extra_networks.parse_prompt(text)\r\n\r\n        if is_positive:\r\n            _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])\r\n        else:\r\n            prompt_flat_list = [text]\r\n\r\n        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)\r\n\r\n    except Exception:\r\n        # a parsing error can happen here during typing, and we don't want to bother the user with\r\n        # messages related to it in console\r\n        prompt_schedules = [[[steps, text]]]\r\n\r\n    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)\r\n    prompts = [prompt_text for step, prompt_text in flat_prompts]\r\n    
token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])\r\n    return f\"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>\"\r\n\r\n\r\ndef update_negative_prompt_token_counter(*args):\r\n    return update_token_counter(*args, is_positive=False)\r\n\r\n\r\ndef setup_progressbar(*args, **kwargs):\r\n    pass\r\n\r\n\r\ndef apply_setting(key, value):\r\n    if value is None:\r\n        return gr.update()\r\n\r\n    if shared.cmd_opts.freeze_settings:\r\n        return gr.update()\r\n\r\n    # dont allow model to be swapped when model hash exists in prompt\r\n    if key == \"sd_model_checkpoint\" and opts.disable_weights_auto_swap:\r\n        return gr.update()\r\n\r\n    if key == \"sd_model_checkpoint\":\r\n        ckpt_info = sd_models.get_closet_checkpoint_match(value)\r\n\r\n        if ckpt_info is not None:\r\n            value = ckpt_info.title\r\n        else:\r\n            return gr.update()\r\n\r\n    comp_args = opts.data_labels[key].component_args\r\n    if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:\r\n        return\r\n\r\n    valtype = type(opts.data_labels[key].default)\r\n    oldval = opts.data.get(key, None)\r\n    opts.data[key] = valtype(value) if valtype != type(None) else value\r\n    if oldval != value and opts.data_labels[key].onchange is not None:\r\n        opts.data_labels[key].onchange()\r\n\r\n    opts.save(shared.config_filename)\r\n    return getattr(opts, key)\r\n\r\n\r\ndef create_output_panel(tabname, outdir, toprow=None):\r\n    return ui_common.create_output_panel(tabname, outdir, toprow)\r\n\r\n\r\ndef ordered_ui_categories():\r\n    user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)}\r\n\r\n    for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)):\r\n        yield category\r\n\r\n\r\ndef 
create_override_settings_dropdown(tabname, row):\r\n    dropdown = gr.Dropdown([], label=\"Override settings\", visible=False, elem_id=f\"{tabname}_override_settings\", multiselect=True)\r\n\r\n    dropdown.change(\r\n        fn=lambda x: gr.Dropdown.update(visible=bool(x)),\r\n        inputs=[dropdown],\r\n        outputs=[dropdown],\r\n    )\r\n\r\n    return dropdown\r\n\r\n\r\ndef create_ui():\r\n    import modules.img2img\r\n    import modules.txt2img\r\n\r\n    reload_javascript()\r\n\r\n    parameters_copypaste.reset()\r\n\r\n    settings = ui_settings.UiSettings()\r\n    settings.register_settings()\r\n\r\n    scripts.scripts_current = scripts.scripts_txt2img\r\n    scripts.scripts_txt2img.initialize_scripts(is_img2img=False)\r\n\r\n    with gr.Blocks(analytics_enabled=False) as txt2img_interface:\r\n        toprow = ui_toprow.Toprow(is_img2img=False, is_compact=shared.opts.compact_prompt_box)\r\n\r\n        dummy_component = gr.Label(visible=False)\r\n\r\n        extra_tabs = gr.Tabs(elem_id=\"txt2img_extra_tabs\", elem_classes=[\"extra-networks\"])\r\n        extra_tabs.__enter__()\r\n\r\n        with gr.Tab(\"Generation\", id=\"txt2img_generation\") as txt2img_generation_tab, ResizeHandleRow(equal_height=False):\r\n            with ExitStack() as stack:\r\n                if shared.opts.txt2img_settings_accordion:\r\n                    stack.enter_context(gr.Accordion(\"Open for Settings\", open=False))\r\n                stack.enter_context(gr.Column(variant='compact', elem_id=\"txt2img_settings\"))\r\n\r\n                scripts.scripts_txt2img.prepare_ui()\r\n\r\n                for category in ordered_ui_categories():\r\n                    if category == \"prompt\":\r\n                        toprow.create_inline_toprow_prompts()\r\n\r\n                    elif category == \"dimensions\":\r\n                        with FormRow():\r\n                            with gr.Column(elem_id=\"txt2img_column_size\", scale=4):\r\n                            
    width = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Width\", value=512, elem_id=\"txt2img_width\")\r\n                                height = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Height\", value=512, elem_id=\"txt2img_height\")\r\n\r\n                            with gr.Column(elem_id=\"txt2img_dimensions_row\", scale=1, elem_classes=\"dimensions-tools\"):\r\n                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id=\"txt2img_res_switch_btn\", tooltip=\"Switch width/height\")\r\n\r\n                            if opts.dimensions_and_batch_together:\r\n                                with gr.Column(elem_id=\"txt2img_column_batch\"):\r\n                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id=\"txt2img_batch_count\")\r\n                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id=\"txt2img_batch_size\")\r\n\r\n                    elif category == \"cfg\":\r\n                        with gr.Row():\r\n                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id=\"txt2img_cfg_scale\")\r\n\r\n                    elif category == \"checkboxes\":\r\n                        with FormRow(elem_classes=\"checkboxes-row\", variant=\"compact\"):\r\n                            pass\r\n\r\n                    elif category == \"accordions\":\r\n                        with gr.Row(elem_id=\"txt2img_accordions\", elem_classes=\"accordions\"):\r\n                            with InputAccordion(False, label=\"Hires. 
fix\", elem_id=\"txt2img_hr\") as enable_hr:\r\n                                with enable_hr.extra():\r\n                                    hr_final_resolution = FormHTML(value=\"\", elem_id=\"txtimg_hr_finalres\", label=\"Upscaled resolution\", interactive=False, min_width=0)\r\n\r\n                                with FormRow(elem_id=\"txt2img_hires_fix_row1\", variant=\"compact\"):\r\n                                    hr_upscaler = gr.Dropdown(label=\"Upscaler\", elem_id=\"txt2img_hr_upscaler\", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)\r\n                                    hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id=\"txt2img_hires_steps\")\r\n                                    denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id=\"txt2img_denoising_strength\")\r\n\r\n                                with FormRow(elem_id=\"txt2img_hires_fix_row2\", variant=\"compact\"):\r\n                                    hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label=\"Upscale by\", value=2.0, elem_id=\"txt2img_hr_scale\")\r\n                                    hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label=\"Resize width to\", value=0, elem_id=\"txt2img_hr_resize_x\")\r\n                                    hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label=\"Resize height to\", value=0, elem_id=\"txt2img_hr_resize_y\")\r\n\r\n                                with FormRow(elem_id=\"txt2img_hires_fix_row3\", variant=\"compact\", visible=opts.hires_fix_show_sampler) as hr_sampler_container:\r\n\r\n                                    hr_checkpoint_name = gr.Dropdown(label='Checkpoint', elem_id=\"hr_checkpoint\", choices=[\"Use same checkpoint\"] + modules.sd_models.checkpoint_tiles(use_short=True), value=\"Use same 
checkpoint\")\r\n                                    create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {\"choices\": [\"Use same checkpoint\"] + modules.sd_models.checkpoint_tiles(use_short=True)}, \"hr_checkpoint_refresh\")\r\n\r\n                                    hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id=\"hr_sampler\", choices=[\"Use same sampler\"] + sd_samplers.visible_sampler_names(), value=\"Use same sampler\")\r\n                                    hr_scheduler = gr.Dropdown(label='Hires schedule type', elem_id=\"hr_scheduler\", choices=[\"Use same scheduler\"] + [x.label for x in sd_schedulers.schedulers], value=\"Use same scheduler\")\r\n\r\n                                with FormRow(elem_id=\"txt2img_hires_fix_row4\", variant=\"compact\", visible=opts.hires_fix_show_prompts) as hr_prompts_container:\r\n                                    with gr.Column(scale=80):\r\n                                        with gr.Row():\r\n                                            hr_prompt = gr.Textbox(label=\"Hires prompt\", elem_id=\"hires_prompt\", show_label=False, lines=3, placeholder=\"Prompt for hires fix pass.\\nLeave empty to use the same prompt as in first pass.\", elem_classes=[\"prompt\"])\r\n                                    with gr.Column(scale=80):\r\n                                        with gr.Row():\r\n                                            hr_negative_prompt = gr.Textbox(label=\"Hires negative prompt\", elem_id=\"hires_neg_prompt\", show_label=False, lines=3, placeholder=\"Negative prompt for hires fix pass.\\nLeave empty to use the same negative prompt as in first pass.\", elem_classes=[\"prompt\"])\r\n\r\n                            scripts.scripts_txt2img.setup_ui_for_section(category)\r\n\r\n                    elif category == \"batch\":\r\n                        if not opts.dimensions_and_batch_together:\r\n                            with 
FormRow(elem_id=\"txt2img_column_batch\"):\r\n                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id=\"txt2img_batch_count\")\r\n                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id=\"txt2img_batch_size\")\r\n\r\n                    elif category == \"override_settings\":\r\n                        with FormRow(elem_id=\"txt2img_override_settings_row\") as row:\r\n                            override_settings = create_override_settings_dropdown('txt2img', row)\r\n\r\n                    elif category == \"scripts\":\r\n                        with FormGroup(elem_id=\"txt2img_script_container\"):\r\n                            custom_inputs = scripts.scripts_txt2img.setup_ui()\r\n\r\n                    if category not in {\"accordions\"}:\r\n                        scripts.scripts_txt2img.setup_ui_for_section(category)\r\n\r\n            hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]\r\n\r\n            for component in hr_resolution_preview_inputs:\r\n                event = component.release if isinstance(component, gr.Slider) else component.change\r\n\r\n                event(\r\n                    fn=calc_resolution_hires,\r\n                    inputs=hr_resolution_preview_inputs,\r\n                    outputs=[hr_final_resolution],\r\n                    show_progress=False,\r\n                )\r\n                event(\r\n                    None,\r\n                    _js=\"onCalcResolutionHires\",\r\n                    inputs=hr_resolution_preview_inputs,\r\n                    outputs=[],\r\n                    show_progress=False,\r\n                )\r\n\r\n            output_panel = create_output_panel(\"txt2img\", opts.outdir_txt2img_samples, toprow)\r\n\r\n            txt2img_inputs = [\r\n                dummy_component,\r\n                toprow.prompt,\r\n                
toprow.negative_prompt,\r\n                toprow.ui_styles.dropdown,\r\n                batch_count,\r\n                batch_size,\r\n                cfg_scale,\r\n                height,\r\n                width,\r\n                enable_hr,\r\n                denoising_strength,\r\n                hr_scale,\r\n                hr_upscaler,\r\n                hr_second_pass_steps,\r\n                hr_resize_x,\r\n                hr_resize_y,\r\n                hr_checkpoint_name,\r\n                hr_sampler_name,\r\n                hr_scheduler,\r\n                hr_prompt,\r\n                hr_negative_prompt,\r\n                override_settings,\r\n            ] + custom_inputs\r\n\r\n            txt2img_outputs = [\r\n                output_panel.gallery,\r\n                output_panel.generation_info,\r\n                output_panel.infotext,\r\n                output_panel.html_log,\r\n            ]\r\n\r\n            txt2img_args = dict(\r\n                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),\r\n                _js=\"submit\",\r\n                inputs=txt2img_inputs,\r\n                outputs=txt2img_outputs,\r\n                show_progress=False,\r\n            )\r\n\r\n            toprow.prompt.submit(**txt2img_args)\r\n            toprow.submit.click(**txt2img_args)\r\n\r\n            output_panel.button_upscale.click(\r\n                fn=wrap_gradio_gpu_call(modules.txt2img.txt2img_upscale, extra_outputs=[None, '', '']),\r\n                _js=\"submit_txt2img_upscale\",\r\n                inputs=txt2img_inputs[0:1] + [output_panel.gallery, dummy_component, output_panel.generation_info] + txt2img_inputs[1:],\r\n                outputs=txt2img_outputs,\r\n                show_progress=False,\r\n            )\r\n\r\n            res_switch_btn.click(fn=None, _js=\"function(){switchWidthHeight('txt2img')}\", inputs=None, outputs=None, show_progress=False)\r\n\r\n            
toprow.restore_progress_button.click(\r\n                fn=progress.restore_progress,\r\n                _js=\"restoreProgressTxt2img\",\r\n                inputs=[dummy_component],\r\n                outputs=[\r\n                    output_panel.gallery,\r\n                    output_panel.generation_info,\r\n                    output_panel.infotext,\r\n                    output_panel.html_log,\r\n                ],\r\n                show_progress=False,\r\n            )\r\n\r\n            txt2img_paste_fields = [\r\n                PasteField(toprow.prompt, \"Prompt\", api=\"prompt\"),\r\n                PasteField(toprow.negative_prompt, \"Negative prompt\", api=\"negative_prompt\"),\r\n                PasteField(cfg_scale, \"CFG scale\", api=\"cfg_scale\"),\r\n                PasteField(width, \"Size-1\", api=\"width\"),\r\n                PasteField(height, \"Size-2\", api=\"height\"),\r\n                PasteField(batch_size, \"Batch size\", api=\"batch_size\"),\r\n                PasteField(toprow.ui_styles.dropdown, lambda d: d[\"Styles array\"] if isinstance(d.get(\"Styles array\"), list) else gr.update(), api=\"styles\"),\r\n                PasteField(denoising_strength, \"Denoising strength\", api=\"denoising_strength\"),\r\n                PasteField(enable_hr, lambda d: \"Denoising strength\" in d and (\"Hires upscale\" in d or \"Hires upscaler\" in d or \"Hires resize-1\" in d), api=\"enable_hr\"),\r\n                PasteField(hr_scale, \"Hires upscale\", api=\"hr_scale\"),\r\n                PasteField(hr_upscaler, \"Hires upscaler\", api=\"hr_upscaler\"),\r\n                PasteField(hr_second_pass_steps, \"Hires steps\", api=\"hr_second_pass_steps\"),\r\n                PasteField(hr_resize_x, \"Hires resize-1\", api=\"hr_resize_x\"),\r\n                PasteField(hr_resize_y, \"Hires resize-2\", api=\"hr_resize_y\"),\r\n                PasteField(hr_checkpoint_name, \"Hires checkpoint\", api=\"hr_checkpoint_name\"),\r\n                
PasteField(hr_sampler_name, sd_samplers.get_hr_sampler_from_infotext, api=\"hr_sampler_name\"),\r\n                PasteField(hr_scheduler, sd_samplers.get_hr_scheduler_from_infotext, api=\"hr_scheduler\"),\r\n                PasteField(hr_sampler_container, lambda d: gr.update(visible=True) if d.get(\"Hires sampler\", \"Use same sampler\") != \"Use same sampler\" or d.get(\"Hires checkpoint\", \"Use same checkpoint\") != \"Use same checkpoint\" or d.get(\"Hires schedule type\", \"Use same scheduler\") != \"Use same scheduler\" else gr.update()),\r\n                PasteField(hr_prompt, \"Hires prompt\", api=\"hr_prompt\"),\r\n                PasteField(hr_negative_prompt, \"Hires negative prompt\", api=\"hr_negative_prompt\"),\r\n                PasteField(hr_prompts_container, lambda d: gr.update(visible=True) if d.get(\"Hires prompt\", \"\") != \"\" or d.get(\"Hires negative prompt\", \"\") != \"\" else gr.update()),\r\n                *scripts.scripts_txt2img.infotext_fields\r\n            ]\r\n            parameters_copypaste.add_paste_fields(\"txt2img\", None, txt2img_paste_fields, override_settings)\r\n            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(\r\n                paste_button=toprow.paste, tabname=\"txt2img\", source_text_component=toprow.prompt, source_image_component=None,\r\n            ))\r\n\r\n            steps = scripts.scripts_txt2img.script('Sampler').steps\r\n\r\n            txt2img_preview_params = [\r\n                toprow.prompt,\r\n                toprow.negative_prompt,\r\n                steps,\r\n                scripts.scripts_txt2img.script('Sampler').sampler_name,\r\n                cfg_scale,\r\n                scripts.scripts_txt2img.script('Seed').seed,\r\n                width,\r\n                height,\r\n            ]\r\n\r\n            toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], 
outputs=[toprow.token_counter])\r\n            toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])\r\n            toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])\r\n            toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])\r\n\r\n        extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')\r\n        ui_extra_networks.setup_ui(extra_networks_ui, output_panel.gallery)\r\n\r\n        extra_tabs.__exit__()\r\n\r\n    scripts.scripts_current = scripts.scripts_img2img\r\n    scripts.scripts_img2img.initialize_scripts(is_img2img=True)\r\n\r\n    with gr.Blocks(analytics_enabled=False) as img2img_interface:\r\n        toprow = ui_toprow.Toprow(is_img2img=True, is_compact=shared.opts.compact_prompt_box)\r\n\r\n        extra_tabs = gr.Tabs(elem_id=\"img2img_extra_tabs\", elem_classes=[\"extra-networks\"])\r\n        extra_tabs.__enter__()\r\n\r\n        with gr.Tab(\"Generation\", id=\"img2img_generation\") as img2img_generation_tab, ResizeHandleRow(equal_height=False):\r\n            with ExitStack() as stack:\r\n                if shared.opts.img2img_settings_accordion:\r\n                    stack.enter_context(gr.Accordion(\"Open for Settings\", open=False))\r\n                stack.enter_context(gr.Column(variant='compact', elem_id=\"img2img_settings\"))\r\n\r\n                copy_image_buttons = []\r\n                copy_image_destinations = {}\r\n\r\n                def add_copy_image_controls(tab_name, elem):\r\n                    with gr.Row(variant=\"compact\", 
elem_id=f\"img2img_copy_to_{tab_name}\"):\r\n                        gr.HTML(\"Copy image to: \", elem_id=f\"img2img_label_copy_to_{tab_name}\")\r\n\r\n                        for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):\r\n                            if name == tab_name:\r\n                                gr.Button(title, interactive=False)\r\n                                copy_image_destinations[name] = elem\r\n                                continue\r\n\r\n                            button = gr.Button(title)\r\n                            copy_image_buttons.append((button, name, elem))\r\n\r\n                scripts.scripts_img2img.prepare_ui()\r\n\r\n                for category in ordered_ui_categories():\r\n                    if category == \"prompt\":\r\n                        toprow.create_inline_toprow_prompts()\r\n\r\n                    if category == \"image\":\r\n                        with gr.Tabs(elem_id=\"mode_img2img\"):\r\n                            img2img_selected_tab = gr.Number(value=0, visible=False)\r\n\r\n                            with gr.TabItem('img2img', id='img2img', elem_id=\"img2img_img2img_tab\") as tab_img2img:\r\n                                init_img = gr.Image(label=\"Image for img2img\", elem_id=\"img2img_image\", show_label=False, source=\"upload\", interactive=True, type=\"pil\", tool=\"editor\", image_mode=\"RGBA\", height=opts.img2img_editor_height)\r\n                                add_copy_image_controls('img2img', init_img)\r\n\r\n                            with gr.TabItem('Sketch', id='img2img_sketch', elem_id=\"img2img_img2img_sketch_tab\") as tab_sketch:\r\n                                sketch = gr.Image(label=\"Image for img2img\", elem_id=\"img2img_sketch\", show_label=False, source=\"upload\", interactive=True, type=\"pil\", tool=\"color-sketch\", image_mode=\"RGB\", height=opts.img2img_editor_height, 
brush_color=opts.img2img_sketch_default_brush_color)\r\n                                add_copy_image_controls('sketch', sketch)\r\n\r\n                            with gr.TabItem('Inpaint', id='inpaint', elem_id=\"img2img_inpaint_tab\") as tab_inpaint:\r\n                                init_img_with_mask = gr.Image(label=\"Image for inpainting with mask\", show_label=False, elem_id=\"img2maskimg\", source=\"upload\", interactive=True, type=\"pil\", tool=\"sketch\", image_mode=\"RGBA\", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color)\r\n                                add_copy_image_controls('inpaint', init_img_with_mask)\r\n\r\n                            with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id=\"img2img_inpaint_sketch_tab\") as tab_inpaint_color:\r\n                                inpaint_color_sketch = gr.Image(label=\"Color sketch inpainting\", show_label=False, elem_id=\"inpaint_sketch\", source=\"upload\", interactive=True, type=\"pil\", tool=\"color-sketch\", image_mode=\"RGB\", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)\r\n                                inpaint_color_sketch_orig = gr.State(None)\r\n                                add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)\r\n\r\n                                def update_orig(image, state):\r\n                                    if image is not None:\r\n                                        same_size = state is not None and state.size == image.size\r\n                                        has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))\r\n                                        edited = same_size and has_exact_match\r\n                                        return image if not edited or state is None else state\r\n\r\n                                inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, 
inpaint_color_sketch_orig], inpaint_color_sketch_orig)\r\n\r\n                            with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id=\"img2img_inpaint_upload_tab\") as tab_inpaint_upload:\r\n                                init_img_inpaint = gr.Image(label=\"Image for img2img\", show_label=False, source=\"upload\", interactive=True, type=\"pil\", elem_id=\"img_inpaint_base\")\r\n                                init_mask_inpaint = gr.Image(label=\"Mask\", source=\"upload\", interactive=True, type=\"pil\", image_mode=\"RGBA\", elem_id=\"img_inpaint_mask\")\r\n\r\n                            with gr.TabItem('Batch', id='batch', elem_id=\"img2img_batch_tab\") as tab_batch:\r\n                                with gr.Tabs(elem_id=\"img2img_batch_source\"):\r\n                                    img2img_batch_source_type = gr.Textbox(visible=False, value=\"upload\")\r\n                                    with gr.TabItem('Upload', id='batch_upload', elem_id=\"img2img_batch_upload_tab\") as tab_batch_upload:\r\n                                        img2img_batch_upload = gr.Files(label=\"Files\", interactive=True, elem_id=\"img2img_batch_upload\")\r\n                                    with gr.TabItem('From directory', id='batch_from_dir', elem_id=\"img2img_batch_from_dir_tab\") as tab_batch_from_dir:\r\n                                        hidden = '<br>Disabled when launched with --hide-ui-dir-config.' 
if shared.cmd_opts.hide_ui_dir_config else ''\r\n                                        gr.HTML(\r\n                                            \"<p style='padding-bottom: 1em;' class=\\\"text-gray-500\\\">Process images in a directory on the same machine where the server is running.\" +\r\n                                            \"<br>Use an empty output directory to save pictures normally instead of writing to the output directory.\" +\r\n                                            f\"<br>Add inpaint batch mask directory to enable inpaint batch processing.\"\r\n                                            f\"{hidden}</p>\"\r\n                                        )\r\n                                        img2img_batch_input_dir = gr.Textbox(label=\"Input directory\", **shared.hide_dirs, elem_id=\"img2img_batch_input_dir\")\r\n                                        img2img_batch_output_dir = gr.Textbox(label=\"Output directory\", **shared.hide_dirs, elem_id=\"img2img_batch_output_dir\")\r\n                                        img2img_batch_inpaint_mask_dir = gr.Textbox(label=\"Inpaint batch mask directory (required for inpaint batch processing only)\", **shared.hide_dirs, elem_id=\"img2img_batch_inpaint_mask_dir\")\r\n                                tab_batch_upload.select(fn=lambda: \"upload\", inputs=[], outputs=[img2img_batch_source_type])\r\n                                tab_batch_from_dir.select(fn=lambda: \"from dir\", inputs=[], outputs=[img2img_batch_source_type])\r\n                                with gr.Accordion(\"PNG info\", open=False):\r\n                                    img2img_batch_use_png_info = gr.Checkbox(label=\"Append png info to prompts\", elem_id=\"img2img_batch_use_png_info\")\r\n                                    img2img_batch_png_info_dir = gr.Textbox(label=\"PNG info directory\", **shared.hide_dirs, placeholder=\"Leave empty to use input directory\", elem_id=\"img2img_batch_png_info_dir\")\r\n                         
           img2img_batch_png_info_props = gr.CheckboxGroup([\"Prompt\", \"Negative prompt\", \"Seed\", \"CFG scale\", \"Sampler\", \"Steps\", \"Model hash\"], label=\"Parameters to take from png info\", info=\"Prompts from png info will be appended to prompts set in ui.\")\r\n\r\n                            img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]\r\n\r\n                            for i, tab in enumerate(img2img_tabs):\r\n                                tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])\r\n\r\n                        def copy_image(img):\r\n                            if isinstance(img, dict) and 'image' in img:\r\n                                return img['image']\r\n\r\n                            return img\r\n\r\n                        for button, name, elem in copy_image_buttons:\r\n                            button.click(\r\n                                fn=copy_image,\r\n                                inputs=[elem],\r\n                                outputs=[copy_image_destinations[name]],\r\n                            )\r\n                            button.click(\r\n                                fn=lambda: None,\r\n                                _js=f\"switch_to_{name.replace(' ', '_')}\",\r\n                                inputs=[],\r\n                                outputs=[],\r\n                            )\r\n\r\n                        with FormRow():\r\n                            resize_mode = gr.Radio(label=\"Resize mode\", elem_id=\"resize_mode\", choices=[\"Just resize\", \"Crop and resize\", \"Resize and fill\", \"Just resize (latent upscale)\"], type=\"index\", value=\"Just resize\")\r\n\r\n                    elif category == \"dimensions\":\r\n                        with FormRow():\r\n                            with gr.Column(elem_id=\"img2img_column_size\", scale=4):\r\n                                
selected_scale_tab = gr.Number(value=0, visible=False)\r\n\r\n                                with gr.Tabs(elem_id=\"img2img_tabs_resize\"):\r\n                                    with gr.Tab(label=\"Resize to\", id=\"to\", elem_id=\"img2img_tab_resize_to\") as tab_scale_to:\r\n                                        with FormRow():\r\n                                            with gr.Column(elem_id=\"img2img_column_size\", scale=4):\r\n                                                width = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Width\", value=512, elem_id=\"img2img_width\")\r\n                                                height = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Height\", value=512, elem_id=\"img2img_height\")\r\n                                            with gr.Column(elem_id=\"img2img_dimensions_row\", scale=1, elem_classes=\"dimensions-tools\"):\r\n                                                res_switch_btn = ToolButton(value=switch_values_symbol, elem_id=\"img2img_res_switch_btn\", tooltip=\"Switch width/height\")\r\n                                                detect_image_size_btn = ToolButton(value=detect_image_size_symbol, elem_id=\"img2img_detect_image_size_btn\", tooltip=\"Auto detect size from img2img\")\r\n\r\n                                    with gr.Tab(label=\"Resize by\", id=\"by\", elem_id=\"img2img_tab_resize_by\") as tab_scale_by:\r\n                                        scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label=\"Scale\", value=1.0, elem_id=\"img2img_scale\")\r\n\r\n                                        with FormRow():\r\n                                            scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id=\"img2img_scale_resolution_preview\")\r\n                                            gr.Slider(label=\"Unused\", elem_id=\"img2img_unused_scale_by_slider\")\r\n                                            button_update_resize_to = 
gr.Button(visible=False, elem_id=\"img2img_update_resize_to\")\r\n\r\n                                    on_change_args = dict(\r\n                                        fn=resize_from_to_html,\r\n                                        _js=\"currentImg2imgSourceResolution\",\r\n                                        inputs=[dummy_component, dummy_component, scale_by],\r\n                                        outputs=scale_by_html,\r\n                                        show_progress=False,\r\n                                    )\r\n\r\n                                    scale_by.release(**on_change_args)\r\n                                    button_update_resize_to.click(**on_change_args)\r\n\r\n                            tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])\r\n                            tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])\r\n\r\n                            if opts.dimensions_and_batch_together:\r\n                                with gr.Column(elem_id=\"img2img_column_batch\"):\r\n                                    batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id=\"img2img_batch_count\")\r\n                                    batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id=\"img2img_batch_size\")\r\n\r\n                    elif category == \"denoising\":\r\n                        denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id=\"img2img_denoising_strength\")\r\n\r\n                    elif category == \"cfg\":\r\n                        with gr.Row():\r\n                            cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id=\"img2img_cfg_scale\")\r\n                            image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, 
elem_id=\"img2img_image_cfg_scale\", visible=False)\r\n\r\n                    elif category == \"checkboxes\":\r\n                        with FormRow(elem_classes=\"checkboxes-row\", variant=\"compact\"):\r\n                            pass\r\n\r\n                    elif category == \"accordions\":\r\n                        with gr.Row(elem_id=\"img2img_accordions\", elem_classes=\"accordions\"):\r\n                            scripts.scripts_img2img.setup_ui_for_section(category)\r\n\r\n                    elif category == \"batch\":\r\n                        if not opts.dimensions_and_batch_together:\r\n                            with FormRow(elem_id=\"img2img_column_batch\"):\r\n                                batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id=\"img2img_batch_count\")\r\n                                batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id=\"img2img_batch_size\")\r\n\r\n                    elif category == \"override_settings\":\r\n                        with FormRow(elem_id=\"img2img_override_settings_row\") as row:\r\n                            override_settings = create_override_settings_dropdown('img2img', row)\r\n\r\n                    elif category == \"scripts\":\r\n                        with FormGroup(elem_id=\"img2img_script_container\"):\r\n                            custom_inputs = scripts.scripts_img2img.setup_ui()\r\n\r\n                    elif category == \"inpaint\":\r\n                        with FormGroup(elem_id=\"inpaint_controls\", visible=False) as inpaint_controls:\r\n                            with FormRow():\r\n                                mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=\"img2img_mask_blur\")\r\n                                mask_alpha = gr.Slider(label=\"Mask transparency\", visible=False, elem_id=\"img2img_mask_alpha\")\r\n\r\n                            with 
FormRow():\r\n                                inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type=\"index\", elem_id=\"img2img_mask_mode\")\r\n\r\n                            with FormRow():\r\n                                inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type=\"index\", elem_id=\"img2img_inpainting_fill\")\r\n\r\n                            with FormRow():\r\n                                with gr.Column():\r\n                                    inpaint_full_res = gr.Radio(label=\"Inpaint area\", choices=[\"Whole picture\", \"Only masked\"], type=\"index\", value=\"Whole picture\", elem_id=\"img2img_inpaint_full_res\")\r\n\r\n                                with gr.Column(scale=4):\r\n                                    inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id=\"img2img_inpaint_full_res_padding\")\r\n\r\n                    if category not in {\"accordions\"}:\r\n                        scripts.scripts_img2img.setup_ui_for_section(category)\r\n\r\n            # the code below is meant to update the resolution label after the image in the image selection UI has changed.\r\n            # as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.\r\n            # I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.\r\n            for component in [init_img, sketch]:\r\n                component.change(fn=lambda: None, _js=\"updateImg2imgResizeToTextAfterChangingImage\", inputs=[], outputs=[], show_progress=False)\r\n\r\n            def select_img2img_tab(tab):\r\n                return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),\r\n\r\n            for i, elem in enumerate(img2img_tabs):\r\n       
         elem.select(\r\n                    fn=lambda tab=i: select_img2img_tab(tab),\r\n                    inputs=[],\r\n                    outputs=[inpaint_controls, mask_alpha],\r\n                )\r\n\r\n            output_panel = create_output_panel(\"img2img\", opts.outdir_img2img_samples, toprow)\r\n\r\n            img2img_args = dict(\r\n                fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),\r\n                _js=\"submit_img2img\",\r\n                inputs=[\r\n                    dummy_component,\r\n                    dummy_component,\r\n                    toprow.prompt,\r\n                    toprow.negative_prompt,\r\n                    toprow.ui_styles.dropdown,\r\n                    init_img,\r\n                    sketch,\r\n                    init_img_with_mask,\r\n                    inpaint_color_sketch,\r\n                    inpaint_color_sketch_orig,\r\n                    init_img_inpaint,\r\n                    init_mask_inpaint,\r\n                    mask_blur,\r\n                    mask_alpha,\r\n                    inpainting_fill,\r\n                    batch_count,\r\n                    batch_size,\r\n                    cfg_scale,\r\n                    image_cfg_scale,\r\n                    denoising_strength,\r\n                    selected_scale_tab,\r\n                    height,\r\n                    width,\r\n                    scale_by,\r\n                    resize_mode,\r\n                    inpaint_full_res,\r\n                    inpaint_full_res_padding,\r\n                    inpainting_mask_invert,\r\n                    img2img_batch_input_dir,\r\n                    img2img_batch_output_dir,\r\n                    img2img_batch_inpaint_mask_dir,\r\n                    override_settings,\r\n                    img2img_batch_use_png_info,\r\n                    img2img_batch_png_info_props,\r\n                    img2img_batch_png_info_dir,\r\n                    
img2img_batch_source_type,\r\n                    img2img_batch_upload,\r\n                ] + custom_inputs,\r\n                outputs=[\r\n                    output_panel.gallery,\r\n                    output_panel.generation_info,\r\n                    output_panel.infotext,\r\n                    output_panel.html_log,\r\n                ],\r\n                show_progress=False,\r\n            )\r\n\r\n            interrogate_args = dict(\r\n                _js=\"get_img2img_tab_index\",\r\n                inputs=[\r\n                    dummy_component,\r\n                    img2img_batch_input_dir,\r\n                    img2img_batch_output_dir,\r\n                    init_img,\r\n                    sketch,\r\n                    init_img_with_mask,\r\n                    inpaint_color_sketch,\r\n                    init_img_inpaint,\r\n                ],\r\n                outputs=[toprow.prompt, dummy_component],\r\n            )\r\n\r\n            toprow.prompt.submit(**img2img_args)\r\n            toprow.submit.click(**img2img_args)\r\n\r\n            res_switch_btn.click(fn=None, _js=\"function(){switchWidthHeight('img2img')}\", inputs=None, outputs=None, show_progress=False)\r\n\r\n            detect_image_size_btn.click(\r\n                fn=lambda w, h, _: (w or gr.update(), h or gr.update()),\r\n                _js=\"currentImg2imgSourceResolution\",\r\n                inputs=[dummy_component, dummy_component, dummy_component],\r\n                outputs=[width, height],\r\n                show_progress=False,\r\n            )\r\n\r\n            toprow.restore_progress_button.click(\r\n                fn=progress.restore_progress,\r\n                _js=\"restoreProgressImg2img\",\r\n                inputs=[dummy_component],\r\n                outputs=[\r\n                    output_panel.gallery,\r\n                    output_panel.generation_info,\r\n                    output_panel.infotext,\r\n                    
output_panel.html_log,\r\n                ],\r\n                show_progress=False,\r\n            )\r\n\r\n            toprow.button_interrogate.click(\r\n                fn=lambda *args: process_interrogate(interrogate, *args),\r\n                **interrogate_args,\r\n            )\r\n\r\n            toprow.button_deepbooru.click(\r\n                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),\r\n                **interrogate_args,\r\n            )\r\n\r\n            steps = scripts.scripts_img2img.script('Sampler').steps\r\n\r\n            toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])\r\n            toprow.ui_styles.dropdown.change(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])\r\n            toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.token_counter])\r\n            toprow.negative_token_button.click(fn=wrap_queued_call(update_negative_prompt_token_counter), inputs=[toprow.negative_prompt, steps, toprow.ui_styles.dropdown], outputs=[toprow.negative_token_counter])\r\n\r\n            img2img_paste_fields = [\r\n                (toprow.prompt, \"Prompt\"),\r\n                (toprow.negative_prompt, \"Negative prompt\"),\r\n                (cfg_scale, \"CFG scale\"),\r\n                (image_cfg_scale, \"Image CFG scale\"),\r\n                (width, \"Size-1\"),\r\n                (height, \"Size-2\"),\r\n                (batch_size, \"Batch size\"),\r\n                (toprow.ui_styles.dropdown, lambda d: d[\"Styles array\"] if isinstance(d.get(\"Styles array\"), list) else gr.update()),\r\n                (denoising_strength, \"Denoising strength\"),\r\n                (mask_blur, \"Mask blur\"),\r\n                
(inpainting_mask_invert, 'Mask mode'),\r\n                (inpainting_fill, 'Masked content'),\r\n                (inpaint_full_res, 'Inpaint area'),\r\n                (inpaint_full_res_padding, 'Masked area padding'),\r\n                *scripts.scripts_img2img.infotext_fields\r\n            ]\r\n            parameters_copypaste.add_paste_fields(\"img2img\", init_img, img2img_paste_fields, override_settings)\r\n            parameters_copypaste.add_paste_fields(\"inpaint\", init_img_with_mask, img2img_paste_fields, override_settings)\r\n            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(\r\n                paste_button=toprow.paste, tabname=\"img2img\", source_text_component=toprow.prompt, source_image_component=None,\r\n            ))\r\n\r\n        extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')\r\n        ui_extra_networks.setup_ui(extra_networks_ui_img2img, output_panel.gallery)\r\n\r\n        extra_tabs.__exit__()\r\n\r\n    scripts.scripts_current = None\r\n\r\n    with gr.Blocks(analytics_enabled=False) as extras_interface:\r\n        ui_postprocessing.create_ui()\r\n\r\n    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:\r\n        with ResizeHandleRow(equal_height=False):\r\n            with gr.Column(variant='panel'):\r\n                image = gr.Image(elem_id=\"pnginfo_image\", label=\"Source\", source=\"upload\", interactive=True, type=\"pil\")\r\n\r\n            with gr.Column(variant='panel'):\r\n                html = gr.HTML()\r\n                generation_info = gr.Textbox(visible=False, elem_id=\"pnginfo_generation_info\")\r\n                html2 = gr.HTML()\r\n                with gr.Row():\r\n                    buttons = parameters_copypaste.create_buttons([\"txt2img\", \"img2img\", \"inpaint\", \"extras\"])\r\n\r\n                for tabname, button in buttons.items():\r\n                    
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(\r\n                        paste_button=button, tabname=tabname, source_text_component=generation_info, source_image_component=image,\r\n                    ))\r\n\r\n        image.change(\r\n            fn=wrap_gradio_call_no_job(modules.extras.run_pnginfo),\r\n            inputs=[image],\r\n            outputs=[html, generation_info, html2],\r\n        )\r\n\r\n    modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()\r\n\r\n    with gr.Blocks(analytics_enabled=False) as train_interface:\r\n        with gr.Row(equal_height=False):\r\n            gr.HTML(value=\"<p style='margin-bottom: 0.7em'>See <b><a href=\\\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\\\">wiki</a></b> for detailed explanation.</p>\")\r\n\r\n        with ResizeHandleRow(variant=\"compact\", equal_height=False):\r\n            with gr.Tabs(elem_id=\"train_tabs\"):\r\n\r\n                with gr.Tab(label=\"Create embedding\", id=\"create_embedding\"):\r\n                    new_embedding_name = gr.Textbox(label=\"Name\", elem_id=\"train_new_embedding_name\")\r\n                    initialization_text = gr.Textbox(label=\"Initialization text\", value=\"*\", elem_id=\"train_initialization_text\")\r\n                    nvpt = gr.Slider(label=\"Number of vectors per token\", minimum=1, maximum=75, step=1, value=1, elem_id=\"train_nvpt\")\r\n                    overwrite_old_embedding = gr.Checkbox(value=False, label=\"Overwrite Old Embedding\", elem_id=\"train_overwrite_old_embedding\")\r\n\r\n                    with gr.Row():\r\n                        with gr.Column(scale=3):\r\n                            gr.HTML(value=\"\")\r\n\r\n                        with gr.Column():\r\n                            create_embedding = gr.Button(value=\"Create embedding\", variant='primary', elem_id=\"train_create_embedding\")\r\n\r\n                with gr.Tab(label=\"Create 
hypernetwork\", id=\"create_hypernetwork\"):\r\n                    new_hypernetwork_name = gr.Textbox(label=\"Name\", elem_id=\"train_new_hypernetwork_name\")\r\n                    new_hypernetwork_sizes = gr.CheckboxGroup(label=\"Modules\", value=[\"768\", \"320\", \"640\", \"1280\"], choices=[\"768\", \"1024\", \"320\", \"640\", \"1280\"], elem_id=\"train_new_hypernetwork_sizes\")\r\n                    new_hypernetwork_layer_structure = gr.Textbox(\"1, 2, 1\", label=\"Enter hypernetwork layer structure\", placeholder=\"1st and last digit must be 1. ex:'1, 2, 1'\", elem_id=\"train_new_hypernetwork_layer_structure\")\r\n                    new_hypernetwork_activation_func = gr.Dropdown(value=\"linear\", label=\"Select activation function of hypernetwork. Recommended : Swish / Linear(none)\", choices=hypernetworks_ui.keys, elem_id=\"train_new_hypernetwork_activation_func\")\r\n                    new_hypernetwork_initialization_option = gr.Dropdown(value = \"Normal\", label=\"Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise\", choices=[\"Normal\", \"KaimingUniform\", \"KaimingNormal\", \"XavierUniform\", \"XavierNormal\"], elem_id=\"train_new_hypernetwork_initialization_option\")\r\n                    new_hypernetwork_add_layer_norm = gr.Checkbox(label=\"Add layer normalization\", elem_id=\"train_new_hypernetwork_add_layer_norm\")\r\n                    new_hypernetwork_use_dropout = gr.Checkbox(label=\"Use dropout\", elem_id=\"train_new_hypernetwork_use_dropout\")\r\n                    new_hypernetwork_dropout_structure = gr.Textbox(\"0, 0, 0\", label=\"Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15\", placeholder=\"1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'\")\r\n                    overwrite_old_hypernetwork = gr.Checkbox(value=False, label=\"Overwrite Old Hypernetwork\", elem_id=\"train_overwrite_old_hypernetwork\")\r\n\r\n                    with gr.Row():\r\n                        with gr.Column(scale=3):\r\n                            gr.HTML(value=\"\")\r\n\r\n                        with gr.Column():\r\n                            create_hypernetwork = gr.Button(value=\"Create hypernetwork\", variant='primary', elem_id=\"train_create_hypernetwork\")\r\n\r\n                def get_textual_inversion_template_names():\r\n                    return sorted(textual_inversion.textual_inversion_templates)\r\n\r\n                with gr.Tab(label=\"Train\", id=\"train\"):\r\n                    gr.HTML(value=\"<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\\\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\\\" style=\\\"font-weight:bold;\\\">[wiki]</a></p>\")\r\n                    with FormRow():\r\n                        train_embedding_name = gr.Dropdown(label='Embedding', elem_id=\"train_embedding\", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))\r\n                        create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {\"choices\": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, \"refresh_train_embedding_name\")\r\n\r\n                        train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id=\"train_hypernetwork\", choices=sorted(shared.hypernetworks))\r\n                        create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {\"choices\": sorted(shared.hypernetworks)}, \"refresh_train_hypernetwork_name\")\r\n\r\n                    with FormRow():\r\n                        embedding_learn_rate = 
gr.Textbox(label='Embedding Learning rate', placeholder=\"Embedding Learning rate\", value=\"0.005\", elem_id=\"train_embedding_learn_rate\")\r\n                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder=\"Hypernetwork Learning rate\", value=\"0.00001\", elem_id=\"train_hypernetwork_learn_rate\")\r\n\r\n                    with FormRow():\r\n                        clip_grad_mode = gr.Dropdown(value=\"disabled\", label=\"Gradient Clipping\", choices=[\"disabled\", \"value\", \"norm\"])\r\n                        clip_grad_value = gr.Textbox(placeholder=\"Gradient clip value\", value=\"0.1\", show_label=False)\r\n\r\n                    with FormRow():\r\n                        batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id=\"train_batch_size\")\r\n                        gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id=\"train_gradient_step\")\r\n\r\n                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder=\"Path to directory with input images\", elem_id=\"train_dataset_directory\")\r\n                    log_directory = gr.Textbox(label='Log directory', placeholder=\"Path to directory where to write outputs\", value=\"textual_inversion\", elem_id=\"train_log_directory\")\r\n\r\n                    with FormRow():\r\n                        template_file = gr.Dropdown(label='Prompt template', value=\"style_filewords.txt\", elem_id=\"train_template_file\", choices=get_textual_inversion_template_names())\r\n                        create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {\"choices\": get_textual_inversion_template_names()}, \"refrsh_train_template_file\")\r\n\r\n                    training_width = gr.Slider(minimum=64, maximum=2048, step=8, label=\"Width\", value=512, elem_id=\"train_training_width\")\r\n                    training_height = 
gr.Slider(minimum=64, maximum=2048, step=8, label=\"Height\", value=512, elem_id=\"train_training_height\")\r\n                    varsize = gr.Checkbox(label=\"Do not resize images\", value=False, elem_id=\"train_varsize\")\r\n                    steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id=\"train_steps\")\r\n\r\n                    with FormRow():\r\n                        create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id=\"train_create_image_every\")\r\n                        save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id=\"train_save_embedding_every\")\r\n\r\n                    use_weight = gr.Checkbox(label=\"Use PNG alpha channel as loss weight\", value=False, elem_id=\"use_weight\")\r\n\r\n                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id=\"train_save_image_with_stored_embedding\")\r\n                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id=\"train_preview_from_txt2img\")\r\n\r\n                    shuffle_tags = gr.Checkbox(label=\"Shuffle tags by ',' when creating prompts.\", value=False, elem_id=\"train_shuffle_tags\")\r\n                    tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label=\"Drop out tags when creating prompts.\", value=0, elem_id=\"train_tag_drop_out\")\r\n\r\n                    latent_sampling_method = gr.Radio(label='Choose latent sampling method', value=\"once\", choices=['once', 'deterministic', 'random'], elem_id=\"train_latent_sampling_method\")\r\n\r\n                    with gr.Row():\r\n                        train_embedding = gr.Button(value=\"Train Embedding\", variant='primary', elem_id=\"train_train_embedding\")\r\n                        interrupt_training = gr.Button(value=\"Interrupt\", elem_id=\"train_interrupt_training\")\r\n                        train_hypernetwork = gr.Button(value=\"Train Hypernetwork\", variant='primary', elem_id=\"train_train_hypernetwork\")\r\n\r\n                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)\r\n\r\n                script_callbacks.ui_train_tabs_callback(params)\r\n\r\n            with gr.Column(elem_id='ti_gallery_container'):\r\n                ti_output = gr.Text(elem_id=\"ti_output\", value=\"\", show_label=False)\r\n                gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery', columns=4)\r\n                gr.HTML(elem_id=\"ti_progress\", value=\"\")\r\n                ti_outcome = gr.HTML(elem_id=\"ti_error\", value=\"\")\r\n\r\n        create_embedding.click(\r\n            fn=textual_inversion_ui.create_embedding,\r\n            inputs=[\r\n                new_embedding_name,\r\n                initialization_text,\r\n                nvpt,\r\n                overwrite_old_embedding,\r\n            ],\r\n            outputs=[\r\n                train_embedding_name,\r\n                ti_output,\r\n  
              ti_outcome,\r\n            ]\r\n        )\r\n\r\n        create_hypernetwork.click(\r\n            fn=hypernetworks_ui.create_hypernetwork,\r\n            inputs=[\r\n                new_hypernetwork_name,\r\n                new_hypernetwork_sizes,\r\n                overwrite_old_hypernetwork,\r\n                new_hypernetwork_layer_structure,\r\n                new_hypernetwork_activation_func,\r\n                new_hypernetwork_initialization_option,\r\n                new_hypernetwork_add_layer_norm,\r\n                new_hypernetwork_use_dropout,\r\n                new_hypernetwork_dropout_structure\r\n            ],\r\n            outputs=[\r\n                train_hypernetwork_name,\r\n                ti_output,\r\n                ti_outcome,\r\n            ]\r\n        )\r\n\r\n        train_embedding.click(\r\n            fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),\r\n            _js=\"start_training_textual_inversion\",\r\n            inputs=[\r\n                dummy_component,\r\n                train_embedding_name,\r\n                embedding_learn_rate,\r\n                batch_size,\r\n                gradient_step,\r\n                dataset_directory,\r\n                log_directory,\r\n                training_width,\r\n                training_height,\r\n                varsize,\r\n                steps,\r\n                clip_grad_mode,\r\n                clip_grad_value,\r\n                shuffle_tags,\r\n                tag_drop_out,\r\n                latent_sampling_method,\r\n                use_weight,\r\n                create_image_every,\r\n                save_embedding_every,\r\n                template_file,\r\n                save_image_with_stored_embedding,\r\n                preview_from_txt2img,\r\n                *txt2img_preview_params,\r\n            ],\r\n            outputs=[\r\n                ti_output,\r\n                ti_outcome,\r\n            
]\r\n        )\r\n\r\n        train_hypernetwork.click(\r\n            fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),\r\n            _js=\"start_training_textual_inversion\",\r\n            inputs=[\r\n                dummy_component,\r\n                train_hypernetwork_name,\r\n                hypernetwork_learn_rate,\r\n                batch_size,\r\n                gradient_step,\r\n                dataset_directory,\r\n                log_directory,\r\n                training_width,\r\n                training_height,\r\n                varsize,\r\n                steps,\r\n                clip_grad_mode,\r\n                clip_grad_value,\r\n                shuffle_tags,\r\n                tag_drop_out,\r\n                latent_sampling_method,\r\n                use_weight,\r\n                create_image_every,\r\n                save_embedding_every,\r\n                template_file,\r\n                preview_from_txt2img,\r\n                *txt2img_preview_params,\r\n            ],\r\n            outputs=[\r\n                ti_output,\r\n                ti_outcome,\r\n            ]\r\n        )\r\n\r\n        interrupt_training.click(\r\n            fn=lambda: shared.state.interrupt(),\r\n            inputs=[],\r\n            outputs=[],\r\n        )\r\n\r\n    loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)\r\n    ui_settings_from_file = loadsave.ui_settings.copy()\r\n\r\n    settings.create_ui(loadsave, dummy_component)\r\n\r\n    interfaces = [\r\n        (txt2img_interface, \"txt2img\", \"txt2img\"),\r\n        (img2img_interface, \"img2img\", \"img2img\"),\r\n        (extras_interface, \"Extras\", \"extras\"),\r\n        (pnginfo_interface, \"PNG Info\", \"pnginfo\"),\r\n        (modelmerger_ui.blocks, \"Checkpoint Merger\", \"modelmerger\"),\r\n        (train_interface, \"Train\", \"train\"),\r\n    ]\r\n\r\n    interfaces += script_callbacks.ui_tabs_callback()\r\n    interfaces += 
[(settings.interface, \"Settings\", \"settings\")]\r\n\r\n    extensions_interface = ui_extensions.create_ui()\r\n    interfaces += [(extensions_interface, \"Extensions\", \"extensions\")]\r\n\r\n    shared.tab_names = []\r\n    for _interface, label, _ifid in interfaces:\r\n        shared.tab_names.append(label)\r\n\r\n    with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title=\"Stable Diffusion\") as demo:\r\n        settings.add_quicksettings()\r\n\r\n        parameters_copypaste.connect_paste_params_buttons()\r\n\r\n        with gr.Tabs(elem_id=\"tabs\") as tabs:\r\n            tab_order = {k: i for i, k in enumerate(opts.ui_tab_order)}\r\n            sorted_interfaces = sorted(interfaces, key=lambda x: tab_order.get(x[1], 9999))\r\n\r\n            for interface, label, ifid in sorted_interfaces:\r\n                if label in shared.opts.hidden_tabs:\r\n                    continue\r\n                with gr.TabItem(label, id=ifid, elem_id=f\"tab_{ifid}\"):\r\n                    interface.render()\r\n\r\n                if ifid not in [\"extensions\", \"settings\"]:\r\n                    loadsave.add_block(interface, ifid)\r\n\r\n            loadsave.add_component(f\"webui/Tabs@{tabs.elem_id}\", tabs)\r\n\r\n            loadsave.setup_ui()\r\n\r\n        if os.path.exists(os.path.join(script_path, \"notification.mp3\")) and shared.opts.notification_audio:\r\n            gr.Audio(interactive=False, value=os.path.join(script_path, \"notification.mp3\"), elem_id=\"audio_notification\", visible=False)\r\n\r\n        footer = shared.html(\"footer.html\")\r\n        footer = footer.format(versions=versions_html(), api_docs=\"/docs\" if shared.cmd_opts.api else \"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API\")\r\n        gr.HTML(footer, elem_id=\"footer\")\r\n\r\n        settings.add_functionality(demo)\r\n\r\n        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and 
shared.sd_model.cond_stage_key == \"edit\")\r\n        settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])\r\n        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])\r\n\r\n        modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])\r\n\r\n    if ui_settings_from_file != loadsave.ui_settings:\r\n        loadsave.dump_defaults()\r\n    demo.ui_loadsave = loadsave\r\n\r\n    return demo\r\n\r\n\r\ndef versions_html():\r\n    import torch\r\n    import launch\r\n\r\n    python_version = \".\".join([str(x) for x in sys.version_info[0:3]])\r\n    commit = launch.commit_hash()\r\n    tag = launch.git_tag()\r\n\r\n    if shared.xformers_available:\r\n        import xformers\r\n        xformers_version = xformers.__version__\r\n    else:\r\n        xformers_version = \"N/A\"\r\n\r\n    return f\"\"\"\r\nversion: <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}\">{tag}</a>\r\n&#x2000;•&#x2000;\r\npython: <span title=\"{sys.version}\">{python_version}</span>\r\n&#x2000;•&#x2000;\r\ntorch: {getattr(torch, '__long_version__',torch.__version__)}\r\n&#x2000;•&#x2000;\r\nxformers: {xformers_version}\r\n&#x2000;•&#x2000;\r\ngradio: {gr.__version__}\r\n&#x2000;•&#x2000;\r\ncheckpoint: <a id=\"sd_checkpoint_hash\">N/A</a>\r\n\"\"\"\r\n\r\n\r\ndef setup_ui_api(app):\r\n    from pydantic import BaseModel, Field\r\n\r\n    class QuicksettingsHint(BaseModel):\r\n        name: str = Field(title=\"Name of the quicksettings field\")\r\n        label: str = Field(title=\"Label of the quicksettings field\")\r\n\r\n    def quicksettings_hint():\r\n        return [QuicksettingsHint(name=k, label=v.label) for k, v in opts.data_labels.items()]\r\n\r\n    app.add_api_route(\"/internal/quicksettings-hint\", quicksettings_hint, methods=[\"GET\"], response_model=list[QuicksettingsHint])\r\n\r\n  
  app.add_api_route(\"/internal/ping\", lambda: {}, methods=[\"GET\"])\r\n\r\n    app.add_api_route(\"/internal/profile-startup\", lambda: timer.startup_record, methods=[\"GET\"])\r\n\r\n    def download_sysinfo(attachment=False):\r\n        from fastapi.responses import PlainTextResponse\r\n\r\n        text = sysinfo.get()\r\n        filename = f\"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.json\"\r\n\r\n        return PlainTextResponse(text, headers={'Content-Disposition': f'{\"attachment\" if attachment else \"inline\"}; filename=\"{filename}\"'})\r\n\r\n    app.add_api_route(\"/internal/sysinfo\", download_sysinfo, methods=[\"GET\"])\r\n    app.add_api_route(\"/internal/sysinfo-download\", lambda: download_sysinfo(attachment=True), methods=[\"GET\"])\r\n\r\n    import fastapi.staticfiles\r\n    app.mount(\"/webui-assets\", fastapi.staticfiles.StaticFiles(directory=launch_utils.repo_dir('stable-diffusion-webui-assets')), name=\"webui-assets\")\r\n"
  },
  {
    "path": "modules/ui_checkpoint_merger.py",
    "content": "\r\nimport gradio as gr\r\n\r\nfrom modules import sd_models, sd_vae, errors, extras, call_queue\r\nfrom modules.ui_components import FormRow\r\nfrom modules.ui_common import create_refresh_button\r\n\r\n\r\ndef update_interp_description(value):\r\n    interp_description_css = \"<p style='margin-bottom: 2.5em'>{}</p>\"\r\n    interp_descriptions = {\r\n        \"No interpolation\": interp_description_css.format(\"No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking.\"),\r\n        \"Weighted sum\": interp_description_css.format(\"A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M\"),\r\n        \"Add difference\": interp_description_css.format(\"The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M\")\r\n    }\r\n    return interp_descriptions[value]\r\n\r\n\r\ndef modelmerger(*args):\r\n    try:\r\n        results = extras.run_modelmerger(*args)\r\n    except Exception as e:\r\n        errors.report(\"Error loading/saving model file\", exc_info=True)\r\n        sd_models.list_models()  # to remove the potentially missing models from the list\r\n        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f\"Error merging checkpoints: {e}\"]\r\n    return results\r\n\r\n\r\nclass UiCheckpointMerger:\r\n    def __init__(self):\r\n        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:\r\n            with gr.Row(equal_height=False):\r\n                with gr.Column(variant='compact'):\r\n                    self.interp_description = gr.HTML(value=update_interp_description(\"Weighted sum\"), elem_id=\"modelmerger_interp_description\")\r\n\r\n                    with FormRow(elem_id=\"modelmerger_models\"):\r\n                        self.primary_model_name = 
gr.Dropdown(sd_models.checkpoint_tiles(), elem_id=\"modelmerger_primary_model_name\", label=\"Primary model (A)\")\r\n                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {\"choices\": sd_models.checkpoint_tiles()}, \"refresh_checkpoint_A\")\r\n\r\n                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id=\"modelmerger_secondary_model_name\", label=\"Secondary model (B)\")\r\n                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {\"choices\": sd_models.checkpoint_tiles()}, \"refresh_checkpoint_B\")\r\n\r\n                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id=\"modelmerger_tertiary_model_name\", label=\"Tertiary model (C)\")\r\n                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {\"choices\": sd_models.checkpoint_tiles()}, \"refresh_checkpoint_C\")\r\n\r\n                    self.custom_name = gr.Textbox(label=\"Custom Name (Optional)\", elem_id=\"modelmerger_custom_name\")\r\n                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id=\"modelmerger_interp_amount\")\r\n                    self.interp_method = gr.Radio(choices=[\"No interpolation\", \"Weighted sum\", \"Add difference\"], value=\"Weighted sum\", label=\"Interpolation Method\", elem_id=\"modelmerger_interp_method\")\r\n                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])\r\n\r\n                    with FormRow():\r\n                        self.checkpoint_format = gr.Radio(choices=[\"ckpt\", \"safetensors\"], value=\"safetensors\", label=\"Checkpoint format\", elem_id=\"modelmerger_checkpoint_format\")\r\n                        self.save_as_half = gr.Checkbox(value=False, label=\"Save 
as float16\", elem_id=\"modelmerger_save_as_half\")\r\n\r\n                    with FormRow():\r\n                        with gr.Column():\r\n                            self.config_source = gr.Radio(choices=[\"A, B or C\", \"B\", \"C\", \"Don't\"], value=\"A, B or C\", label=\"Copy config from\", type=\"index\", elem_id=\"modelmerger_config_method\")\r\n\r\n                        with gr.Column():\r\n                            with FormRow():\r\n                                self.bake_in_vae = gr.Dropdown(choices=[\"None\"] + list(sd_vae.vae_dict), value=\"None\", label=\"Bake in VAE\", elem_id=\"modelmerger_bake_in_vae\")\r\n                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {\"choices\": [\"None\"] + list(sd_vae.vae_dict)}, \"modelmerger_refresh_bake_in_vae\")\r\n\r\n                    with FormRow():\r\n                        self.discard_weights = gr.Textbox(value=\"\", label=\"Discard weights with matching name\", elem_id=\"modelmerger_discard_weights\")\r\n\r\n                    with gr.Accordion(\"Metadata\", open=False) as metadata_editor:\r\n                        with FormRow():\r\n                            self.save_metadata = gr.Checkbox(value=True, label=\"Save metadata\", elem_id=\"modelmerger_save_metadata\")\r\n                            self.add_merge_recipe = gr.Checkbox(value=True, label=\"Add merge recipe metadata\", elem_id=\"modelmerger_add_recipe\")\r\n                            self.copy_metadata_fields = gr.Checkbox(value=True, label=\"Copy metadata from merged models\", elem_id=\"modelmerger_copy_metadata\")\r\n\r\n                        self.metadata_json = gr.TextArea('{}', label=\"Metadata in JSON format\")\r\n                        self.read_metadata = gr.Button(\"Read metadata from selected checkpoints\")\r\n\r\n                    with FormRow():\r\n                        self.modelmerger_merge = gr.Button(elem_id=\"modelmerger_merge\", value=\"Merge\", 
variant='primary')\r\n\r\n                with gr.Column(variant='compact', elem_id=\"modelmerger_results_container\"):\r\n                    with gr.Group(elem_id=\"modelmerger_results_panel\"):\r\n                        self.modelmerger_result = gr.HTML(elem_id=\"modelmerger_result\", show_label=False)\r\n\r\n        self.metadata_editor = metadata_editor\r\n        self.blocks = modelmerger_interface\r\n\r\n    def setup_ui(self, dummy_component, sd_model_checkpoint_component):\r\n        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)\r\n\r\n        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])\r\n\r\n        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])\r\n        self.modelmerger_merge.click(\r\n            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),\r\n            _js='modelmerger',\r\n            inputs=[\r\n                dummy_component,\r\n                self.primary_model_name,\r\n                self.secondary_model_name,\r\n                self.tertiary_model_name,\r\n                self.interp_method,\r\n                self.interp_amount,\r\n                self.save_as_half,\r\n                self.custom_name,\r\n                self.checkpoint_format,\r\n                self.config_source,\r\n                self.bake_in_vae,\r\n                self.discard_weights,\r\n                self.save_metadata,\r\n                self.add_merge_recipe,\r\n                self.copy_metadata_fields,\r\n                self.metadata_json,\r\n            ],\r\n            outputs=[\r\n                self.primary_model_name,\r\n                self.secondary_model_name,\r\n                
self.tertiary_model_name,\r\n                sd_model_checkpoint_component,\r\n                self.modelmerger_result,\r\n            ]\r\n        )\r\n\r\n        # Required as a workaround for change() event not triggering when loading values from ui-config.json\r\n        self.interp_description.value = update_interp_description(self.interp_method.value)\r\n\r\n"
  },
  {
    "path": "modules/ui_common.py",
    "content": "import csv\r\nimport dataclasses\r\nimport json\r\nimport html\r\nimport os\r\nfrom contextlib import nullcontext\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import call_queue, shared, ui_tempdir, util\r\nfrom modules.infotext_utils import image_from_url_text\r\nimport modules.images\r\nfrom modules.ui_components import ToolButton\r\nimport modules.infotext_utils as parameters_copypaste\r\n\r\nfolder_symbol = '\\U0001f4c2'  # 📂\r\nrefresh_symbol = '\\U0001f504'  # 🔄\r\n\r\n\r\ndef update_generation_info(generation_info, html_info, img_index):\r\n    try:\r\n        generation_info = json.loads(generation_info)\r\n        if img_index < 0 or img_index >= len(generation_info[\"infotexts\"]):\r\n            return html_info, gr.update()\r\n        return plaintext_to_html(generation_info[\"infotexts\"][img_index]), gr.update()\r\n    except Exception:\r\n        pass\r\n    # if the json parse or anything else fails, just return the old html_info\r\n    return html_info, gr.update()\r\n\r\n\r\ndef plaintext_to_html(text, classname=None):\r\n    content = \"<br>\\n\".join(html.escape(x) for x in text.split('\\n'))\r\n\r\n    return f\"<p class='{classname}'>{content}</p>\" if classname else f\"<p>{content}</p>\"\r\n\r\n\r\ndef update_logfile(logfile_path, fields):\r\n    \"\"\"Update a logfile from old format to new format to maintain CSV integrity.\"\"\"\r\n    with open(logfile_path, \"r\", encoding=\"utf8\", newline=\"\") as file:\r\n        reader = csv.reader(file)\r\n        rows = list(reader)\r\n\r\n    # blank file: leave it as is\r\n    if not rows:\r\n        return\r\n\r\n    # file is already synced, do nothing\r\n    if len(rows[0]) == len(fields):\r\n        return\r\n\r\n    rows[0] = fields\r\n\r\n    # append new fields to each row as empty values\r\n    for row in rows[1:]:\r\n        while len(row) < len(fields):\r\n            row.append(\"\")\r\n\r\n    with open(logfile_path, \"w\", encoding=\"utf8\", newline=\"\") as file:\r\n 
       writer = csv.writer(file)\r\n        writer.writerows(rows)\r\n\r\n\r\ndef save_files(js_data, images, do_make_zip, index):\r\n    filenames = []\r\n    fullfns = []\r\n    parsed_infotexts = []\r\n\r\n    # quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it\r\n    class MyObject:\r\n        def __init__(self, d=None):\r\n            if d is not None:\r\n                for key, value in d.items():\r\n                    setattr(self, key, value)\r\n\r\n    data = json.loads(js_data)\r\n    p = MyObject(data)\r\n\r\n    path = shared.opts.outdir_save\r\n    save_to_dirs = shared.opts.use_save_to_dirs_for_ui\r\n    extension: str = shared.opts.samples_format\r\n    start_index = 0\r\n\r\n    if index > -1 and shared.opts.save_selected_only and (index >= data[\"index_of_first_image\"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only\r\n        images = [images[index]]\r\n        start_index = index\r\n\r\n    os.makedirs(shared.opts.outdir_save, exist_ok=True)\r\n\r\n    fields = [\r\n        \"prompt\",\r\n        \"seed\",\r\n        \"width\",\r\n        \"height\",\r\n        \"sampler\",\r\n        \"cfgs\",\r\n        \"steps\",\r\n        \"filename\",\r\n        \"negative_prompt\",\r\n        \"sd_model_name\",\r\n        \"sd_model_hash\",\r\n    ]\r\n    logfile_path = os.path.join(shared.opts.outdir_save, \"log.csv\")\r\n\r\n    # NOTE: ensure csv integrity when fields are added by\r\n    # updating headers and padding with delimiters where needed\r\n    if shared.opts.save_write_log_csv and os.path.exists(logfile_path):\r\n        update_logfile(logfile_path, fields)\r\n\r\n    with (open(logfile_path, \"a\", encoding=\"utf8\", newline='') if shared.opts.save_write_log_csv else nullcontext()) as file:\r\n        if file:\r\n            at_start = file.tell() == 0\r\n            writer = csv.writer(file)\r\n            if at_start:\r\n                
writer.writerow(fields)\r\n\r\n        for image_index, filedata in enumerate(images, start_index):\r\n            image = image_from_url_text(filedata)\r\n\r\n            is_grid = image_index < p.index_of_first_image\r\n\r\n            p.batch_index = image_index-1\r\n\r\n            parameters = parameters_copypaste.parse_generation_parameters(data[\"infotexts\"][image_index], [])\r\n            parsed_infotexts.append(parameters)\r\n            fullfn, txt_fullfn = modules.images.save_image(image, path, \"\", seed=parameters['Seed'], prompt=parameters['Prompt'], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)\r\n\r\n            filename = os.path.relpath(fullfn, path)\r\n            filenames.append(filename)\r\n            fullfns.append(fullfn)\r\n            if txt_fullfn:\r\n                filenames.append(os.path.basename(txt_fullfn))\r\n                fullfns.append(txt_fullfn)\r\n\r\n        if file:\r\n            writer.writerow([parsed_infotexts[0]['Prompt'], parsed_infotexts[0]['Seed'], data[\"width\"], data[\"height\"], data[\"sampler_name\"], data[\"cfg_scale\"], data[\"steps\"], filenames[0], parsed_infotexts[0]['Negative prompt'], data[\"sd_model_name\"], data[\"sd_model_hash\"]])\r\n\r\n    # Make Zip\r\n    if do_make_zip:\r\n        p.all_seeds = [parameters['Seed'] for parameters in parsed_infotexts]\r\n        namegen = modules.images.FilenameGenerator(p, parsed_infotexts[0]['Seed'], parsed_infotexts[0]['Prompt'], image, True)\r\n        zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or \"[datetime]_[[model_name]]_[seed]-[seed_last]\")\r\n        zip_filepath = os.path.join(path, f\"{zip_filename}.zip\")\r\n\r\n        from zipfile import ZipFile\r\n        with ZipFile(zip_filepath, \"w\") as zip_file:\r\n            for i in range(len(fullfns)):\r\n                with open(fullfns[i], mode=\"rb\") as f:\r\n                    zip_file.writestr(filenames[i], 
f.read())\r\n        fullfns.insert(0, zip_filepath)\r\n\r\n    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f\"Saved: {filenames[0]}\")\r\n\r\n\r\n@dataclasses.dataclass\r\nclass OutputPanel:\r\n    gallery = None\r\n    generation_info = None\r\n    infotext = None\r\n    html_log = None\r\n    button_upscale = None\r\n\r\n\r\ndef create_output_panel(tabname, outdir, toprow=None):\r\n    res = OutputPanel()\r\n\r\n    def open_folder(f, images=None, index=None):\r\n        if shared.cmd_opts.hide_ui_dir_config:\r\n            return\r\n\r\n        try:\r\n            if 'Sub' in shared.opts.open_dir_button_choice:\r\n                image_dir = os.path.split(images[index][\"name\"].rsplit('?', 1)[0])[0]\r\n                if 'temp' in shared.opts.open_dir_button_choice or not ui_tempdir.is_gradio_temp_path(image_dir):\r\n                    f = image_dir\r\n        except Exception:\r\n            pass\r\n\r\n        util.open_folder(f)\r\n\r\n    with gr.Column(elem_id=f\"{tabname}_results\"):\r\n        if toprow:\r\n            toprow.create_inline_toprow_image()\r\n\r\n        with gr.Column(variant='panel', elem_id=f\"{tabname}_results_panel\"):\r\n            with gr.Group(elem_id=f\"{tabname}_gallery_container\"):\r\n                res.gallery = gr.Gallery(label='Output', show_label=False, elem_id=f\"{tabname}_gallery\", columns=4, preview=True, height=shared.opts.gallery_height or None)\r\n\r\n            with gr.Row(elem_id=f\"image_buttons_{tabname}\", elem_classes=\"image-buttons\"):\r\n                open_folder_button = ToolButton(folder_symbol, elem_id=f'{tabname}_open_folder', visible=not shared.cmd_opts.hide_ui_dir_config, tooltip=\"Open images output directory.\")\r\n\r\n                if tabname != \"extras\":\r\n                    save = ToolButton('💾', elem_id=f'save_{tabname}', tooltip=f\"Save the image to a dedicated directory ({shared.opts.outdir_save}).\")\r\n                    save_zip = ToolButton('🗃️', 
elem_id=f'save_zip_{tabname}', tooltip=f\"Save zip archive with images to a dedicated directory ({shared.opts.outdir_save})\")\r\n\r\n                buttons = {\r\n                    'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip=\"Send image and generation parameters to img2img tab.\"),\r\n                    'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip=\"Send image and generation parameters to img2img inpaint tab.\"),\r\n                    'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip=\"Send image and generation parameters to extras tab.\")\r\n                }\r\n\r\n                if tabname == 'txt2img':\r\n                    res.button_upscale = ToolButton('✨', elem_id=f'{tabname}_upscale', tooltip=\"Create an upscaled version of the current image using hires fix settings.\")\r\n\r\n            open_folder_button.click(\r\n                fn=lambda images, index: open_folder(shared.opts.outdir_samples or outdir, images, index),\r\n                _js=\"(y, w) => [y, selected_gallery_index()]\",\r\n                inputs=[\r\n                    res.gallery,\r\n                    open_folder_button,  # placeholder for index\r\n                ],\r\n                outputs=[],\r\n            )\r\n\r\n            if tabname != \"extras\":\r\n                download_files = gr.File(None, file_count=\"multiple\", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')\r\n\r\n                with gr.Group():\r\n                    res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes=\"infotext\")\r\n                    res.html_log = gr.HTML(elem_id=f'html_log_{tabname}', elem_classes=\"html-log\")\r\n\r\n                    res.generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')\r\n                    if tabname == 'txt2img' or tabname == 'img2img':\r\n                        generation_info_button = 
gr.Button(visible=False, elem_id=f\"{tabname}_generation_info_button\")\r\n                        generation_info_button.click(\r\n                            fn=update_generation_info,\r\n                            _js=\"function(x, y, z){ return [x, y, selected_gallery_index()] }\",\r\n                            inputs=[res.generation_info, res.infotext, res.infotext],\r\n                            outputs=[res.infotext, res.infotext],\r\n                            show_progress=False,\r\n                        )\r\n\r\n                    save.click(\r\n                        fn=call_queue.wrap_gradio_call_no_job(save_files),\r\n                        _js=\"(x, y, z, w) => [x, y, false, selected_gallery_index()]\",\r\n                        inputs=[\r\n                            res.generation_info,\r\n                            res.gallery,\r\n                            res.infotext,\r\n                            res.infotext,\r\n                        ],\r\n                        outputs=[\r\n                            download_files,\r\n                            res.html_log,\r\n                        ],\r\n                        show_progress=False,\r\n                    )\r\n\r\n                    save_zip.click(\r\n                        fn=call_queue.wrap_gradio_call_no_job(save_files),\r\n                        _js=\"(x, y, z, w) => [x, y, true, selected_gallery_index()]\",\r\n                        inputs=[\r\n                            res.generation_info,\r\n                            res.gallery,\r\n                            res.infotext,\r\n                            res.infotext,\r\n                        ],\r\n                        outputs=[\r\n                            download_files,\r\n                            res.html_log,\r\n                        ]\r\n                    )\r\n\r\n            else:\r\n                res.generation_info = gr.HTML(elem_id=f'html_info_x_{tabname}')\r\n                
res.infotext = gr.HTML(elem_id=f'html_info_{tabname}', elem_classes=\"infotext\")\r\n                res.html_log = gr.HTML(elem_id=f'html_log_{tabname}')\r\n\r\n            paste_field_names = []\r\n            if tabname == \"txt2img\":\r\n                paste_field_names = modules.scripts.scripts_txt2img.paste_field_names\r\n            elif tabname == \"img2img\":\r\n                paste_field_names = modules.scripts.scripts_img2img.paste_field_names\r\n\r\n            for paste_tabname, paste_button in buttons.items():\r\n                parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(\r\n                    paste_button=paste_button, tabname=paste_tabname, source_tabname=\"txt2img\" if tabname == \"txt2img\" else None, source_image_component=res.gallery,\r\n                    paste_field_names=paste_field_names\r\n                ))\r\n\r\n    return res\r\n\r\n\r\ndef create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):\r\n    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]\r\n\r\n    label = None\r\n    for comp in refresh_components:\r\n        label = getattr(comp, 'label', None)\r\n        if label is not None:\r\n            break\r\n\r\n    def refresh():\r\n        refresh_method()\r\n        args = refreshed_args() if callable(refreshed_args) else refreshed_args\r\n\r\n        for k, v in args.items():\r\n            for comp in refresh_components:\r\n                setattr(comp, k, v)\r\n\r\n        return [gr.update(**(args or {})) for _ in refresh_components] if len(refresh_components) > 1 else gr.update(**(args or {}))\r\n\r\n    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f\"{label}: refresh\" if label else \"Refresh\")\r\n    refresh_button.click(\r\n        fn=refresh,\r\n        inputs=[],\r\n        outputs=refresh_components\r\n    )\r\n    return refresh_button\r\n\r\n\r\ndef 
setup_dialog(button_show, dialog, *, button_close=None):\r\n    \"\"\"Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window.\"\"\"\r\n\r\n    dialog.visible = False\r\n\r\n    button_show.click(\r\n        fn=lambda: gr.update(visible=True),\r\n        inputs=[],\r\n        outputs=[dialog],\r\n    ).then(fn=None, _js=\"function(){ popupId('\" + dialog.elem_id + \"'); }\")\r\n\r\n    if button_close:\r\n        button_close.click(fn=None, _js=\"closePopup\")\r\n\r\n"
  },
  {
    "path": "modules/ui_components.py",
    "content": "import gradio as gr\r\n\r\n\r\nclass FormComponent:\r\n    def get_expected_parent(self):\r\n        return gr.components.Form\r\n\r\n\r\ngr.Dropdown.get_expected_parent = FormComponent.get_expected_parent\r\n\r\n\r\nclass ToolButton(FormComponent, gr.Button):\r\n    \"\"\"Small button with single emoji as text, fits inside gradio forms\"\"\"\r\n\r\n    def __init__(self, *args, **kwargs):\r\n        classes = kwargs.pop(\"elem_classes\", [])\r\n        super().__init__(*args, elem_classes=[\"tool\", *classes], **kwargs)\r\n\r\n    def get_block_name(self):\r\n        return \"button\"\r\n\r\n\r\nclass ResizeHandleRow(gr.Row):\r\n    \"\"\"Same as gr.Row but fits inside gradio forms\"\"\"\r\n\r\n    def __init__(self, **kwargs):\r\n        super().__init__(**kwargs)\r\n\r\n        self.elem_classes.append(\"resize-handle-row\")\r\n\r\n    def get_block_name(self):\r\n        return \"row\"\r\n\r\n\r\nclass FormRow(FormComponent, gr.Row):\r\n    \"\"\"Same as gr.Row but fits inside gradio forms\"\"\"\r\n\r\n    def get_block_name(self):\r\n        return \"row\"\r\n\r\n\r\nclass FormColumn(FormComponent, gr.Column):\r\n    \"\"\"Same as gr.Column but fits inside gradio forms\"\"\"\r\n\r\n    def get_block_name(self):\r\n        return \"column\"\r\n\r\n\r\nclass FormGroup(FormComponent, gr.Group):\r\n    \"\"\"Same as gr.Group but fits inside gradio forms\"\"\"\r\n\r\n    def get_block_name(self):\r\n        return \"group\"\r\n\r\n\r\nclass FormHTML(FormComponent, gr.HTML):\r\n    \"\"\"Same as gr.HTML but fits inside gradio forms\"\"\"\r\n\r\n    def get_block_name(self):\r\n        return \"html\"\r\n\r\n\r\nclass FormColorPicker(FormComponent, gr.ColorPicker):\r\n    \"\"\"Same as gr.ColorPicker but fits inside gradio forms\"\"\"\r\n\r\n    def get_block_name(self):\r\n        return \"colorpicker\"\r\n\r\n\r\nclass DropdownMulti(FormComponent, gr.Dropdown):\r\n    \"\"\"Same as gr.Dropdown but always multiselect\"\"\"\r\n    def __init__(self, 
**kwargs):\r\n        super().__init__(multiselect=True, **kwargs)\r\n\r\n    def get_block_name(self):\r\n        return \"dropdown\"\r\n\r\n\r\nclass DropdownEditable(FormComponent, gr.Dropdown):\r\n    \"\"\"Same as gr.Dropdown but allows editing value\"\"\"\r\n    def __init__(self, **kwargs):\r\n        super().__init__(allow_custom_value=True, **kwargs)\r\n\r\n    def get_block_name(self):\r\n        return \"dropdown\"\r\n\r\n\r\nclass InputAccordion(gr.Checkbox):\r\n    \"\"\"A gr.Accordion that can be used as an input - returns True if open, False if closed.\r\n\r\n    Actually just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.\r\n    \"\"\"\r\n\r\n    global_index = 0\r\n\r\n    def __init__(self, value, **kwargs):\r\n        self.accordion_id = kwargs.get('elem_id')\r\n        if self.accordion_id is None:\r\n            self.accordion_id = f\"input-accordion-{InputAccordion.global_index}\"\r\n            InputAccordion.global_index += 1\r\n\r\n        kwargs_checkbox = {\r\n            **kwargs,\r\n            \"elem_id\": f\"{self.accordion_id}-checkbox\",\r\n            \"visible\": False,\r\n        }\r\n        super().__init__(value, **kwargs_checkbox)\r\n\r\n        self.change(fn=None, _js='function(checked){ inputAccordionChecked(\"' + self.accordion_id + '\", checked); }', inputs=[self])\r\n\r\n        kwargs_accordion = {\r\n            **kwargs,\r\n            \"elem_id\": self.accordion_id,\r\n            \"label\": kwargs.get('label', 'Accordion'),\r\n            \"elem_classes\": ['input-accordion'],\r\n            \"open\": value,\r\n        }\r\n        self.accordion = gr.Accordion(**kwargs_accordion)\r\n\r\n    def extra(self):\r\n        \"\"\"Allows you to put something into the label of the accordion.\r\n\r\n        Use it like this:\r\n\r\n        ```\r\n        with InputAccordion(False, label=\"Accordion\") as acc:\r\n            with acc.extra():\r\n                
FormHTML(value=\"hello\", min_width=0)\r\n\r\n            ...\r\n        ```\r\n        \"\"\"\r\n\r\n        return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)\r\n\r\n    def __enter__(self):\r\n        self.accordion.__enter__()\r\n        return self\r\n\r\n    def __exit__(self, exc_type, exc_val, exc_tb):\r\n        self.accordion.__exit__(exc_type, exc_val, exc_tb)\r\n\r\n    def get_block_name(self):\r\n        return \"checkbox\"\r\n\r\n"
  },
  {
    "path": "modules/ui_extensions.py",
    "content": "import json\r\nimport os\r\nimport threading\r\nimport time\r\nfrom datetime import datetime, timezone\r\n\r\nimport git\r\n\r\nimport gradio as gr\r\nimport html\r\nimport shutil\r\nimport errno\r\n\r\nfrom modules import extensions, shared, paths, config_states, errors, restart\r\nfrom modules.paths_internal import config_states_dir\r\nfrom modules.call_queue import wrap_gradio_gpu_call\r\n\r\navailable_extensions = {\"extensions\": []}\r\nSTYLE_PRIMARY = ' style=\"color: var(--primary-400)\"'\r\n\r\n\r\ndef check_access():\r\n    assert not shared.cmd_opts.disable_extension_access, \"extension access disabled because of command line flags\"\r\n\r\n\r\ndef apply_and_restart(disable_list, update_list, disable_all):\r\n    check_access()\r\n\r\n    disabled = json.loads(disable_list)\r\n    assert type(disabled) == list, f\"wrong disable_list data for apply_and_restart: {disable_list}\"\r\n\r\n    update = json.loads(update_list)\r\n    assert type(update) == list, f\"wrong update_list data for apply_and_restart: {update_list}\"\r\n\r\n    if update:\r\n        save_config_state(\"Backup (pre-update)\")\r\n\r\n    update = set(update)\r\n\r\n    for ext in extensions.extensions:\r\n        if ext.name not in update:\r\n            continue\r\n\r\n        try:\r\n            ext.fetch_and_reset_hard()\r\n        except Exception:\r\n            errors.report(f\"Error getting updates for {ext.name}\", exc_info=True)\r\n\r\n    shared.opts.disabled_extensions = disabled\r\n    shared.opts.disable_all_extensions = disable_all\r\n    shared.opts.save(shared.config_filename)\r\n\r\n    if restart.is_restartable():\r\n        restart.restart_program()\r\n    else:\r\n        restart.stop_program()\r\n\r\n\r\ndef save_config_state(name):\r\n    current_config_state = config_states.get_config()\r\n\r\n    name = os.path.basename(name or \"Config\")\r\n\r\n    current_config_state[\"name\"] = name\r\n    timestamp = 
datetime.now().strftime('%Y_%m_%d-%H_%M_%S')\r\n    filename = os.path.join(config_states_dir, f\"{timestamp}_{name}.json\")\r\n    print(f\"Saving backup of webui/extension state to {filename}.\")\r\n    with open(filename, \"w\", encoding=\"utf-8\") as f:\r\n        json.dump(current_config_state, f, indent=4, ensure_ascii=False)\r\n    config_states.list_config_states()\r\n    new_value = next(iter(config_states.all_config_states.keys()), \"Current\")\r\n    new_choices = [\"Current\"] + list(config_states.all_config_states.keys())\r\n    return gr.Dropdown.update(value=new_value, choices=new_choices), f\"<span>Saved current webui/extension state to \\\"{filename}\\\"</span>\"\r\n\r\n\r\ndef restore_config_state(confirmed, config_state_name, restore_type):\r\n    if config_state_name == \"Current\":\r\n        return \"<span>Select a config to restore from.</span>\"\r\n    if not confirmed:\r\n        return \"<span>Cancelled.</span>\"\r\n\r\n    check_access()\r\n\r\n    config_state = config_states.all_config_states[config_state_name]\r\n\r\n    print(f\"*** Restoring webui state from backup: {restore_type} ***\")\r\n\r\n    if restore_type == \"extensions\" or restore_type == \"both\":\r\n        shared.opts.restore_config_state_file = config_state[\"filepath\"]\r\n        shared.opts.save(shared.config_filename)\r\n\r\n    if restore_type == \"webui\" or restore_type == \"both\":\r\n        config_states.restore_webui_config(config_state)\r\n\r\n    shared.state.request_restart()\r\n\r\n    return \"\"\r\n\r\n\r\ndef check_updates(id_task, disable_list):\r\n    check_access()\r\n\r\n    disabled = json.loads(disable_list)\r\n    assert type(disabled) == list, f\"wrong disable_list data for apply_and_restart: {disable_list}\"\r\n\r\n    exts = [ext for ext in extensions.extensions if ext.remote is not None and ext.name not in disabled]\r\n    shared.state.job_count = len(exts)\r\n\r\n    for ext in exts:\r\n        shared.state.textinfo = ext.name\r\n\r\n     
   try:\r\n            ext.check_updates()\r\n        except FileNotFoundError as e:\r\n            if 'FETCH_HEAD' not in str(e):\r\n                raise\r\n        except Exception:\r\n            errors.report(f\"Error checking updates for {ext.name}\", exc_info=True)\r\n\r\n        shared.state.nextjob()\r\n\r\n    return extension_table(), \"\"\r\n\r\n\r\ndef make_commit_link(commit_hash, remote, text=None):\r\n    if text is None:\r\n        text = commit_hash[:8]\r\n    if remote.startswith(\"https://github.com/\"):\r\n        if remote.endswith(\".git\"):\r\n            remote = remote[:-4]\r\n        href = remote + \"/commit/\" + commit_hash\r\n        return f'<a href=\"{href}\" target=\"_blank\">{text}</a>'\r\n    else:\r\n        return text\r\n\r\n\r\ndef extension_table():\r\n    code = f\"\"\"<!-- {time.time()} -->\r\n    <table id=\"extensions\">\r\n        <thead>\r\n            <tr>\r\n                <th>\r\n                    <input class=\"gr-check-radio gr-checkbox all_extensions_toggle\" type=\"checkbox\" {'checked=\"checked\"' if all(ext.enabled for ext in extensions.extensions) else ''} onchange=\"toggle_all_extensions(event)\" />\r\n                    <abbr title=\"Use checkbox to enable the extension; it will be enabled or disabled when you click apply button\">Extension</abbr>\r\n                </th>\r\n                <th>URL</th>\r\n                <th>Branch</th>\r\n                <th>Version</th>\r\n                <th>Date</th>\r\n                <th><abbr title=\"Use checkbox to mark the extension for update; it will be updated when you click apply button\">Update</abbr></th>\r\n            </tr>\r\n        </thead>\r\n        <tbody>\r\n    \"\"\"\r\n\r\n    for ext in extensions.extensions:\r\n        ext: extensions.Extension\r\n        ext.read_info_from_repo()\r\n\r\n        remote = f\"\"\"<a href=\"{html.escape(ext.remote or '')}\" target=\"_blank\">{html.escape(\"built-in\" if ext.is_builtin else ext.remote or 
'')}</a>\"\"\"\r\n\r\n        if ext.can_update:\r\n            ext_status = f\"\"\"<label><input class=\"gr-check-radio gr-checkbox\" name=\"update_{html.escape(ext.name)}\" checked=\"checked\" type=\"checkbox\">{html.escape(ext.status)}</label>\"\"\"\r\n        else:\r\n            ext_status = ext.status\r\n\r\n        style = \"\"\r\n        if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == \"extra\" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == \"all\":\r\n            style = STYLE_PRIMARY\r\n\r\n        version_link = ext.version\r\n        if ext.commit_hash and ext.remote:\r\n            version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)\r\n\r\n        code += f\"\"\"\r\n            <tr>\r\n                <td><label{style}><input class=\"gr-check-radio gr-checkbox extension_toggle\" name=\"enable_{html.escape(ext.name)}\" type=\"checkbox\" {'checked=\"checked\"' if ext.enabled else ''} onchange=\"toggle_extension(event)\" />{html.escape(ext.name)}</label></td>\r\n                <td>{remote}</td>\r\n                <td>{ext.branch}</td>\r\n                <td>{version_link}</td>\r\n                <td>{datetime.fromtimestamp(ext.commit_date) if ext.commit_date else \"\"}</td>\r\n                <td{' class=\"extension_status\"' if ext.remote is not None else ''}>{ext_status}</td>\r\n            </tr>\r\n    \"\"\"\r\n\r\n    code += \"\"\"\r\n        </tbody>\r\n    </table>\r\n    \"\"\"\r\n\r\n    return code\r\n\r\n\r\ndef update_config_states_table(state_name):\r\n    if state_name == \"Current\":\r\n        config_state = config_states.get_config()\r\n    else:\r\n        config_state = config_states.all_config_states[state_name]\r\n\r\n    config_name = config_state.get(\"name\", \"Config\")\r\n    created_date = datetime.fromtimestamp(config_state[\"created_at\"]).strftime('%Y-%m-%d %H:%M:%S')\r\n    filepath = 
config_state.get(\"filepath\", \"<unknown>\")\r\n\r\n    try:\r\n        webui_remote = config_state[\"webui\"][\"remote\"] or \"\"\r\n        webui_branch = config_state[\"webui\"][\"branch\"]\r\n        webui_commit_hash = config_state[\"webui\"][\"commit_hash\"] or \"<unknown>\"\r\n        webui_commit_date = config_state[\"webui\"][\"commit_date\"]\r\n        if webui_commit_date:\r\n            webui_commit_date = time.asctime(time.gmtime(webui_commit_date))\r\n        else:\r\n            webui_commit_date = \"<unknown>\"\r\n\r\n        remote = f\"\"\"<a href=\"{html.escape(webui_remote)}\" target=\"_blank\">{html.escape(webui_remote or '')}</a>\"\"\"\r\n        commit_link = make_commit_link(webui_commit_hash, webui_remote)\r\n        date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)\r\n\r\n        current_webui = config_states.get_webui_config()\r\n\r\n        style_remote = \"\"\r\n        style_branch = \"\"\r\n        style_commit = \"\"\r\n        if current_webui[\"remote\"] != webui_remote:\r\n            style_remote = STYLE_PRIMARY\r\n        if current_webui[\"branch\"] != webui_branch:\r\n            style_branch = STYLE_PRIMARY\r\n        if current_webui[\"commit_hash\"] != webui_commit_hash:\r\n            style_commit = STYLE_PRIMARY\r\n\r\n        code = f\"\"\"<!-- {time.time()} -->\r\n<h2>Config Backup: {config_name}</h2>\r\n<div><b>Filepath:</b> {filepath}</div>\r\n<div><b>Created at:</b> {created_date}</div>\r\n<h2>WebUI State</h2>\r\n<table id=\"config_state_webui\">\r\n    <thead>\r\n        <tr>\r\n            <th>URL</th>\r\n            <th>Branch</th>\r\n            <th>Commit</th>\r\n            <th>Date</th>\r\n        </tr>\r\n    </thead>\r\n    <tbody>\r\n        <tr>\r\n            <td>\r\n                <label{style_remote}>{remote}</label>\r\n            </td>\r\n            <td>\r\n                <label{style_branch}>{webui_branch}</label>\r\n            </td>\r\n            <td>\r\n         
       <label{style_commit}>{commit_link}</label>\r\n            </td>\r\n            <td>\r\n                <label{style_commit}>{date_link}</label>\r\n            </td>\r\n        </tr>\r\n    </tbody>\r\n</table>\r\n<h2>Extension State</h2>\r\n<table id=\"config_state_extensions\">\r\n    <thead>\r\n        <tr>\r\n            <th>Extension</th>\r\n            <th>URL</th>\r\n            <th>Branch</th>\r\n            <th>Commit</th>\r\n            <th>Date</th>\r\n        </tr>\r\n    </thead>\r\n    <tbody>\r\n\"\"\"\r\n\r\n        ext_map = {ext.name: ext for ext in extensions.extensions}\r\n\r\n        for ext_name, ext_conf in config_state[\"extensions\"].items():\r\n            ext_remote = ext_conf[\"remote\"] or \"\"\r\n            ext_branch = ext_conf[\"branch\"] or \"<unknown>\"\r\n            ext_enabled = ext_conf[\"enabled\"]\r\n            ext_commit_hash = ext_conf[\"commit_hash\"] or \"<unknown>\"\r\n            ext_commit_date = ext_conf[\"commit_date\"]\r\n            if ext_commit_date:\r\n                ext_commit_date = time.asctime(time.gmtime(ext_commit_date))\r\n            else:\r\n                ext_commit_date = \"<unknown>\"\r\n\r\n            remote = f\"\"\"<a href=\"{html.escape(ext_remote)}\" target=\"_blank\">{html.escape(ext_remote or '')}</a>\"\"\"\r\n            commit_link = make_commit_link(ext_commit_hash, ext_remote)\r\n            date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)\r\n\r\n            style_enabled = \"\"\r\n            style_remote = \"\"\r\n            style_branch = \"\"\r\n            style_commit = \"\"\r\n            if ext_name in ext_map:\r\n                current_ext = ext_map[ext_name]\r\n                current_ext.read_info_from_repo()\r\n                if current_ext.enabled != ext_enabled:\r\n                    style_enabled = STYLE_PRIMARY\r\n                if current_ext.remote != ext_remote:\r\n                    style_remote = STYLE_PRIMARY\r\n              
  if current_ext.branch != ext_branch:\r\n                    style_branch = STYLE_PRIMARY\r\n                if current_ext.commit_hash != ext_commit_hash:\r\n                    style_commit = STYLE_PRIMARY\r\n\r\n            code += f\"\"\"        <tr>\r\n            <td><label{style_enabled}><input class=\"gr-check-radio gr-checkbox\" type=\"checkbox\" disabled=\"true\" {'checked=\"checked\"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>\r\n            <td><label{style_remote}>{remote}</label></td>\r\n            <td><label{style_branch}>{ext_branch}</label></td>\r\n            <td><label{style_commit}>{commit_link}</label></td>\r\n            <td><label{style_commit}>{date_link}</label></td>\r\n        </tr>\r\n\"\"\"\r\n\r\n        code += \"\"\"    </tbody>\r\n</table>\"\"\"\r\n\r\n    except Exception as e:\r\n        print(f\"[ERROR]: Config states {filepath}, {e}\")\r\n        code = f\"\"\"<!-- {time.time()} -->\r\n<h2>Config Backup: {config_name}</h2>\r\n<div><b>Filepath:</b> {filepath}</div>\r\n<div><b>Created at:</b> {created_date}</div>\r\n<h2>This file is corrupted</h2>\"\"\"\r\n\r\n    return code\r\n\r\n\r\ndef normalize_git_url(url):\r\n    if url is None:\r\n        return \"\"\r\n\r\n    url = url.replace(\".git\", \"\")\r\n    return url\r\n\r\n\r\ndef get_extension_dirname_from_url(url):\r\n    *parts, last_part = url.split('/')\r\n    return normalize_git_url(last_part)\r\n\r\n\r\ndef install_extension_from_url(dirname, url, branch_name=None):\r\n    check_access()\r\n\r\n    if isinstance(dirname, str):\r\n        dirname = dirname.strip()\r\n    if isinstance(url, str):\r\n        url = url.strip()\r\n\r\n    assert url, 'No URL specified'\r\n\r\n    if dirname is None or dirname == \"\":\r\n        dirname = get_extension_dirname_from_url(url)\r\n\r\n    target_dir = os.path.join(extensions.extensions_dir, dirname)\r\n    assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'\r\n\r\n    
normalized_url = normalize_git_url(url)\r\n    if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url):\r\n        raise Exception(f'Extension with this URL is already installed: {url}')\r\n\r\n    tmpdir = os.path.join(paths.data_path, \"tmp\", dirname)\r\n\r\n    try:\r\n        shutil.rmtree(tmpdir, True)\r\n        if not branch_name:\r\n            # if no branch is specified, use the default branch\r\n            with git.Repo.clone_from(url, tmpdir, filter=['blob:none']) as repo:\r\n                repo.remote().fetch()\r\n                for submodule in repo.submodules:\r\n                    submodule.update()\r\n        else:\r\n            with git.Repo.clone_from(url, tmpdir, filter=['blob:none'], branch=branch_name) as repo:\r\n                repo.remote().fetch()\r\n                for submodule in repo.submodules:\r\n                    submodule.update()\r\n        try:\r\n            os.rename(tmpdir, target_dir)\r\n        except OSError as err:\r\n            if err.errno == errno.EXDEV:\r\n                # Cross device link, typical in docker or when tmp/ and extensions/ are on different file systems\r\n                # Since we can't use a rename, do the slower but more versatile shutil.move()\r\n                shutil.move(tmpdir, target_dir)\r\n            else:\r\n                # Something else, not enough free space, permissions, etc.  rethrow it so that it gets handled.\r\n                raise err\r\n\r\n        import launch\r\n        launch.run_extension_installer(target_dir)\r\n\r\n        extensions.list_extensions()\r\n        return [extension_table(), html.escape(f\"Installed into {target_dir}. 
Use Installed tab to restart.\")]\r\n    finally:\r\n        shutil.rmtree(tmpdir, True)\r\n\r\n\r\ndef install_extension_from_index(url, selected_tags, showing_type, filtering_type, sort_column, filter_text):\r\n    ext_table, message = install_extension_from_url(None, url)\r\n\r\n    code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text)\r\n\r\n    return code, ext_table, message, ''\r\n\r\n\r\ndef refresh_available_extensions(url, selected_tags, showing_type, filtering_type, sort_column):\r\n    global available_extensions\r\n\r\n    import urllib.request\r\n    with urllib.request.urlopen(url) as response:\r\n        text = response.read()\r\n\r\n    available_extensions = json.loads(text)\r\n\r\n    code, tags = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column)\r\n\r\n    return url, code, gr.CheckboxGroup.update(choices=tags), '', ''\r\n\r\n\r\ndef refresh_available_extensions_for_tags(selected_tags, showing_type, filtering_type, sort_column, filter_text):\r\n    code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text)\r\n\r\n    return code, ''\r\n\r\n\r\ndef search_extensions(filter_text, selected_tags, showing_type, filtering_type, sort_column):\r\n    code, _ = refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text)\r\n\r\n    return code, ''\r\n\r\n\r\nsort_ordering = [\r\n    # (reverse, order_by_function)\r\n    (True, lambda x: x.get('added', 'z')),\r\n    (False, lambda x: x.get('added', 'z')),\r\n    (False, lambda x: x.get('name', 'z')),\r\n    (True, lambda x: x.get('name', 'z')),\r\n    (False, lambda x: 'z'),\r\n    (True, lambda x: x.get('commit_time', '')),\r\n    (True, lambda x: x.get('created_at', '')),\r\n    (True, lambda x: x.get('stars', 0)),\r\n]\r\n\r\n\r\ndef get_date(info: dict, key):\r\n    try:\r\n        
return datetime.strptime(info.get(key), \"%Y-%m-%dT%H:%M:%SZ\").replace(tzinfo=timezone.utc).astimezone().strftime(\"%Y-%m-%d\")\r\n    except (ValueError, TypeError):\r\n        return ''\r\n\r\n\r\ndef refresh_available_extensions_from_data(selected_tags, showing_type, filtering_type, sort_column, filter_text=\"\"):\r\n    extlist = available_extensions[\"extensions\"]\r\n    installed_extensions = {extension.name for extension in extensions.extensions}\r\n    installed_extension_urls = {normalize_git_url(extension.remote) for extension in extensions.extensions if extension.remote is not None}\r\n\r\n    tags = available_extensions.get(\"tags\", {})\r\n    selected_tags = set(selected_tags)\r\n    hidden = 0\r\n\r\n    code = f\"\"\"<!-- {time.time()} -->\r\n    <table id=\"available_extensions\">\r\n        <thead>\r\n            <tr>\r\n                <th>Extension</th>\r\n                <th>Description</th>\r\n                <th>Action</th>\r\n            </tr>\r\n        </thead>\r\n        <tbody>\r\n    \"\"\"\r\n\r\n    sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0]\r\n\r\n    for ext in sorted(extlist, key=sort_function, reverse=sort_reverse):\r\n        name = ext.get(\"name\", \"noname\")\r\n        stars = int(ext.get(\"stars\", 0))\r\n        added = ext.get('added', 'unknown')\r\n        update_time = get_date(ext, 'commit_time')\r\n        create_time = get_date(ext, 'created_at')\r\n        url = ext.get(\"url\", None)\r\n        description = ext.get(\"description\", \"\")\r\n        extension_tags = ext.get(\"tags\", [])\r\n\r\n        if url is None:\r\n            continue\r\n\r\n        existing = get_extension_dirname_from_url(url) in installed_extensions or normalize_git_url(url) in installed_extension_urls\r\n        extension_tags = extension_tags + [\"installed\"] if existing else extension_tags\r\n\r\n        if len(selected_tags) > 0:\r\n            matched_tags = [x for x in 
extension_tags if x in selected_tags]\r\n            if filtering_type == 'or':\r\n                need_hide = len(matched_tags) > 0\r\n            else:\r\n                need_hide = len(matched_tags) == len(selected_tags)\r\n\r\n            if showing_type == 'show':\r\n                need_hide = not need_hide\r\n\r\n            if need_hide:\r\n                hidden += 1\r\n                continue\r\n\r\n        if filter_text and filter_text.strip():\r\n            if filter_text.lower() not in html.escape(name).lower() and filter_text.lower() not in html.escape(description).lower():\r\n                hidden += 1\r\n                continue\r\n\r\n        install_code = f\"\"\"<button onclick=\"install_extension_from_index(this, '{html.escape(url)}')\" {\"disabled=disabled\" if existing else \"\"} class=\"lg secondary gradio-button custom-button\">{\"Install\" if not existing else \"Installed\"}</button>\"\"\"\r\n\r\n        tags_text = \", \".join([f\"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>\" for x in extension_tags])\r\n\r\n        code += f\"\"\"\r\n            <tr>\r\n                <td><a href=\"{html.escape(url)}\" target=\"_blank\">{html.escape(name)}</a><br />{tags_text}</td>\r\n                <td>{html.escape(description)}<p class=\"info\">\r\n                <span class=\"date_added\">Update: {html.escape(update_time)}  Added: {html.escape(added)}  Created: {html.escape(create_time)}</span><span class=\"star_count\">stars: <b>{stars}</b></a></p></td>\r\n                <td>{install_code}</td>\r\n            </tr>\r\n\r\n        \"\"\"\r\n\r\n        for tag in [x for x in extension_tags if x not in tags]:\r\n            tags[tag] = tag\r\n\r\n    code += \"\"\"\r\n        </tbody>\r\n    </table>\r\n    \"\"\"\r\n\r\n    if hidden > 0:\r\n        code += f\"<p>Extension hidden: {hidden}</p>\"\r\n\r\n    return code, list(tags)\r\n\r\n\r\ndef preload_extensions_git_metadata():\r\n    for extension in 
extensions.extensions:\r\n        extension.read_info_from_repo()\r\n\r\n\r\ndef create_ui():\r\n    import modules.ui\r\n\r\n    config_states.list_config_states()\r\n\r\n    threading.Thread(target=preload_extensions_git_metadata).start()\r\n\r\n    with gr.Blocks(analytics_enabled=False) as ui:\r\n        with gr.Tabs(elem_id=\"tabs_extensions\"):\r\n            with gr.TabItem(\"Installed\", id=\"installed\"):\r\n\r\n                with gr.Row(elem_id=\"extensions_installed_top\"):\r\n                    apply_label = (\"Apply and restart UI\" if restart.is_restartable() else \"Apply and quit\")\r\n                    apply = gr.Button(value=apply_label, variant=\"primary\")\r\n                    check = gr.Button(value=\"Check for updates\")\r\n                    extensions_disable_all = gr.Radio(label=\"Disable all extensions\", choices=[\"none\", \"extra\", \"all\"], value=shared.opts.disable_all_extensions, elem_id=\"extensions_disable_all\")\r\n                    extensions_disabled_list = gr.Text(elem_id=\"extensions_disabled_list\", visible=False, container=False)\r\n                    extensions_update_list = gr.Text(elem_id=\"extensions_update_list\", visible=False, container=False)\r\n                    refresh = gr.Button(value='Refresh', variant=\"compact\")\r\n\r\n                html = \"\"\r\n\r\n                if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != \"none\":\r\n                    if shared.cmd_opts.disable_all_extensions:\r\n                        msg = '\"--disable-all-extensions\" was used, remove it to load all extensions again'\r\n                    elif shared.opts.disable_all_extensions != \"none\":\r\n                        msg = '\"Disable all extensions\" was set, change it to \"none\" to load all extensions again'\r\n                    elif shared.cmd_opts.disable_extra_extensions:\r\n                        msg = 
'\"--disable-extra-extensions\" was used, remove it to load all extensions again'\r\n                    html = f'<span style=\"color: var(--primary-400);\">{msg}</span>'\r\n\r\n                with gr.Row():\r\n                    info = gr.HTML(html)\r\n\r\n                with gr.Row(elem_classes=\"progress-container\"):\r\n                    extensions_table = gr.HTML('Loading...', elem_id=\"extensions_installed_html\")\r\n\r\n                ui.load(fn=extension_table, inputs=[], outputs=[extensions_table], show_progress=False)\r\n                refresh.click(fn=extension_table, inputs=[], outputs=[extensions_table], show_progress=False)\r\n\r\n                apply.click(\r\n                    fn=apply_and_restart,\r\n                    _js=\"extensions_apply\",\r\n                    inputs=[extensions_disabled_list, extensions_update_list, extensions_disable_all],\r\n                    outputs=[],\r\n                )\r\n\r\n                check.click(\r\n                    fn=wrap_gradio_gpu_call(check_updates, extra_outputs=[gr.update()]),\r\n                    _js=\"extensions_check\",\r\n                    inputs=[info, extensions_disabled_list],\r\n                    outputs=[extensions_table, info],\r\n                )\r\n\r\n            with gr.TabItem(\"Available\", id=\"available\"):\r\n                with gr.Row():\r\n                    refresh_available_extensions_button = gr.Button(value=\"Load from:\", variant=\"primary\")\r\n                    extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', \"https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json\")\r\n                    available_extensions_index = gr.Text(value=extensions_index_url, label=\"Extension index URL\", container=False)\r\n                    extension_to_install = gr.Text(elem_id=\"extension_to_install\", visible=False)\r\n                    install_extension_button = 
gr.Button(elem_id=\"install_extension_button\", visible=False)\r\n\r\n                with gr.Row():\r\n                    selected_tags = gr.CheckboxGroup(value=[\"ads\", \"localization\", \"installed\"], label=\"Extension tags\", choices=[\"script\", \"ads\", \"localization\", \"installed\"], elem_classes=['compact-checkbox-group'])\r\n                    sort_column = gr.Radio(value=\"newest first\", label=\"Order\", choices=[\"newest first\", \"oldest first\", \"a-z\", \"z-a\", \"internal order\",'update time', 'create time', \"stars\"], type=\"index\", elem_classes=['compact-checkbox-group'])\r\n\r\n                with gr.Row():\r\n                    showing_type = gr.Radio(value=\"hide\", label=\"Showing type\", choices=[\"hide\", \"show\"], elem_classes=['compact-checkbox-group'])\r\n                    filtering_type = gr.Radio(value=\"or\", label=\"Filtering type\", choices=[\"or\", \"and\"], elem_classes=['compact-checkbox-group'])\r\n\r\n                with gr.Row():\r\n                    search_extensions_text = gr.Text(label=\"Search\", container=False)\r\n\r\n                install_result = gr.HTML()\r\n                available_extensions_table = gr.HTML()\r\n\r\n                refresh_available_extensions_button.click(\r\n                    fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update(), gr.update()]),\r\n                    inputs=[available_extensions_index, selected_tags, showing_type, filtering_type, sort_column],\r\n                    outputs=[available_extensions_index, available_extensions_table, selected_tags, search_extensions_text, install_result],\r\n                )\r\n\r\n                install_extension_button.click(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),\r\n                    inputs=[extension_to_install, selected_tags, showing_type, filtering_type, sort_column, 
search_extensions_text],\r\n                    outputs=[available_extensions_table, extensions_table, install_result],\r\n                )\r\n\r\n                search_extensions_text.change(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(search_extensions, extra_outputs=[gr.update()]),\r\n                    inputs=[search_extensions_text, selected_tags, showing_type, filtering_type, sort_column],\r\n                    outputs=[available_extensions_table, install_result],\r\n                )\r\n\r\n                selected_tags.change(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),\r\n                    inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],\r\n                    outputs=[available_extensions_table, install_result]\r\n                )\r\n\r\n                showing_type.change(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),\r\n                    inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],\r\n                    outputs=[available_extensions_table, install_result]\r\n                )\r\n\r\n                filtering_type.change(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),\r\n                    inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],\r\n                    outputs=[available_extensions_table, install_result]\r\n                )\r\n\r\n                sort_column.change(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),\r\n                    inputs=[selected_tags, showing_type, filtering_type, sort_column, search_extensions_text],\r\n                    
outputs=[available_extensions_table, install_result]\r\n                )\r\n\r\n            with gr.TabItem(\"Install from URL\", id=\"install_from_url\"):\r\n                install_url = gr.Text(label=\"URL for extension's git repository\")\r\n                install_branch = gr.Text(label=\"Specific branch name\", placeholder=\"Leave empty for default main branch\")\r\n                install_dirname = gr.Text(label=\"Local directory name\", placeholder=\"Leave empty for auto\")\r\n                install_button = gr.Button(value=\"Install\", variant=\"primary\")\r\n                install_result = gr.HTML(elem_id=\"extension_install_result\")\r\n\r\n                install_button.click(\r\n                    fn=modules.ui.wrap_gradio_call_no_job(lambda *args: [gr.update(), *install_extension_from_url(*args)], extra_outputs=[gr.update(), gr.update()]),\r\n                    inputs=[install_dirname, install_url, install_branch],\r\n                    outputs=[install_url, extensions_table, install_result],\r\n                )\r\n\r\n            with gr.TabItem(\"Backup/Restore\"):\r\n                with gr.Row(elem_id=\"extensions_backup_top_row\"):\r\n                    config_states_list = gr.Dropdown(label=\"Saved Configs\", elem_id=\"extension_backup_saved_configs\", value=\"Current\", choices=[\"Current\"] + list(config_states.all_config_states.keys()))\r\n                    modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {\"choices\": [\"Current\"] + list(config_states.all_config_states.keys())}, \"refresh_config_states\")\r\n                    config_restore_type = gr.Radio(label=\"State to restore\", choices=[\"extensions\", \"webui\", \"both\"], value=\"extensions\", elem_id=\"extension_backup_restore_type\")\r\n                    config_restore_button = gr.Button(value=\"Restore Selected Config\", variant=\"primary\", elem_id=\"extension_backup_restore\")\r\n                with 
gr.Row(elem_id=\"extensions_backup_top_row2\"):\r\n                    config_save_name = gr.Textbox(\"\", placeholder=\"Config Name\", show_label=False)\r\n                    config_save_button = gr.Button(value=\"Save Current Config\")\r\n\r\n                config_states_info = gr.HTML(\"\")\r\n                config_states_table = gr.HTML(\"Loading...\")\r\n                ui.load(fn=update_config_states_table, inputs=[config_states_list], outputs=[config_states_table])\r\n\r\n                config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])\r\n\r\n                dummy_component = gr.Label(visible=False)\r\n                config_restore_button.click(fn=restore_config_state, _js=\"config_state_confirm_restore\", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])\r\n\r\n                config_states_list.change(\r\n                    fn=update_config_states_table,\r\n                    inputs=[config_states_list],\r\n                    outputs=[config_states_table],\r\n                )\r\n\r\n\r\n    return ui\r\n"
  },
  {
    "path": "modules/ui_extra_networks.py",
    "content": "import functools\r\nimport os.path\r\nimport urllib.parse\r\nfrom base64 import b64decode\r\nfrom io import BytesIO\r\nfrom pathlib import Path\r\nfrom typing import Optional, Union\r\nfrom dataclasses import dataclass\r\n\r\nfrom modules import shared, ui_extra_networks_user_metadata, errors, extra_networks, util\r\nfrom modules.images import read_info_from_image, save_image_with_geninfo\r\nimport gradio as gr\r\nimport json\r\nimport html\r\nfrom fastapi.exceptions import HTTPException\r\nfrom PIL import Image\r\n\r\nfrom modules.infotext_utils import image_from_url_text\r\n\r\nextra_pages = []\r\nallowed_dirs = set()\r\ndefault_allowed_preview_extensions = [\"png\", \"jpg\", \"jpeg\", \"webp\", \"gif\"]\r\n\r\n@functools.cache\r\ndef allowed_preview_extensions_with_extra(extra_extensions=None):\r\n    return set(default_allowed_preview_extensions) | set(extra_extensions or [])\r\n\r\n\r\ndef allowed_preview_extensions():\r\n    return allowed_preview_extensions_with_extra((shared.opts.samples_format, ))\r\n\r\n\r\n@dataclass\r\nclass ExtraNetworksItem:\r\n    \"\"\"Wrapper for dictionaries representing ExtraNetworks items.\"\"\"\r\n    item: dict\r\n\r\n\r\ndef get_tree(paths: Union[str, list[str]], items: dict[str, ExtraNetworksItem]) -> dict:\r\n    \"\"\"Recursively builds a directory tree.\r\n\r\n    Args:\r\n        paths: Path or list of paths to directories. 
These paths are treated as roots from which\r\n            the tree will be built.\r\n        items: A dictionary associating filepaths to an ExtraNetworksItem instance.\r\n\r\n    Returns:\r\n        The result directory tree.\r\n    \"\"\"\r\n    if isinstance(paths, (str,)):\r\n        paths = [paths]\r\n\r\n    def _get_tree(_paths: list[str], _root: str):\r\n        _res = {}\r\n        for path in _paths:\r\n            relpath = os.path.relpath(path, _root)\r\n            if os.path.isdir(path):\r\n                dir_items = os.listdir(path)\r\n                # Ignore empty directories.\r\n                if not dir_items:\r\n                    continue\r\n                dir_tree = _get_tree([os.path.join(path, x) for x in dir_items], _root)\r\n                # We only want to store non-empty folders in the tree.\r\n                if dir_tree:\r\n                    _res[relpath] = dir_tree\r\n            else:\r\n                if path not in items:\r\n                    continue\r\n                # Add the ExtraNetworksItem to the result.\r\n                _res[relpath] = items[path]\r\n        return _res\r\n\r\n    res = {}\r\n    # Handle each root directory separately.\r\n    # Each root WILL have a key/value at the root of the result dict though\r\n    # the value can be an empty dict if the directory is empty. 
We want these\r\n    # placeholders for empty dirs so we can inform the user later.\r\n    for path in paths:\r\n        root = os.path.dirname(path)\r\n        relpath = os.path.relpath(path, root)\r\n        # Wrap the path in a list since that is what the `_get_tree` expects.\r\n        res[relpath] = _get_tree([path], root)\r\n        if res[relpath]:\r\n            # We need to pull the inner path out one for these root dirs.\r\n            res[relpath] = res[relpath][relpath]\r\n\r\n    return res\r\n\r\ndef register_page(page):\r\n    \"\"\"registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions\"\"\"\r\n\r\n    extra_pages.append(page)\r\n    allowed_dirs.clear()\r\n    allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))\r\n\r\n\r\ndef fetch_file(filename: str = \"\"):\r\n    from starlette.responses import FileResponse\r\n\r\n    if not os.path.isfile(filename):\r\n        raise HTTPException(status_code=404, detail=\"File not found\")\r\n\r\n    if not any(Path(x).absolute() in Path(filename).absolute().parents for x in allowed_dirs):\r\n        raise ValueError(f\"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.\")\r\n\r\n    ext = os.path.splitext(filename)[1].lower()[1:]\r\n    if ext not in allowed_preview_extensions():\r\n        raise ValueError(f\"File cannot be fetched: {filename}. 
Extensions allowed: {allowed_preview_extensions()}.\")\r\n\r\n    # would profit from returning 304\r\n    return FileResponse(filename, headers={\"Accept-Ranges\": \"bytes\"})\r\n\r\n\r\ndef fetch_cover_images(page: str = \"\", item: str = \"\", index: int = 0):\r\n    from starlette.responses import Response\r\n\r\n    page = next(iter([x for x in extra_pages if x.name == page]), None)\r\n    if page is None:\r\n        raise HTTPException(status_code=404, detail=\"File not found\")\r\n\r\n    metadata = page.metadata.get(item)\r\n    if metadata is None:\r\n        raise HTTPException(status_code=404, detail=\"File not found\")\r\n\r\n    cover_images = json.loads(metadata.get('ssmd_cover_images', {}))\r\n    image = cover_images[index] if index < len(cover_images) else None\r\n    if not image:\r\n        raise HTTPException(status_code=404, detail=\"File not found\")\r\n\r\n    try:\r\n        image = Image.open(BytesIO(b64decode(image)))\r\n        buffer = BytesIO()\r\n        image.save(buffer, format=image.format)\r\n        return Response(content=buffer.getvalue(), media_type=image.get_format_mimetype())\r\n    except Exception as err:\r\n        raise ValueError(f\"File cannot be fetched: {item}. 
Failed to load cover image.\") from err\r\n\r\n\r\ndef get_metadata(page: str = \"\", item: str = \"\"):\r\n    from starlette.responses import JSONResponse\r\n\r\n    page = next(iter([x for x in extra_pages if x.name == page]), None)\r\n    if page is None:\r\n        return JSONResponse({})\r\n\r\n    metadata = page.metadata.get(item)\r\n    if metadata is None:\r\n        return JSONResponse({})\r\n\r\n    metadata = {i:metadata[i] for i in metadata if i != 'ssmd_cover_images'}  # those are cover images, and they are too big to display in UI as text\r\n\r\n    return JSONResponse({\"metadata\": json.dumps(metadata, indent=4, ensure_ascii=False)})\r\n\r\n\r\ndef get_single_card(page: str = \"\", tabname: str = \"\", name: str = \"\"):\r\n    from starlette.responses import JSONResponse\r\n\r\n    page = next(iter([x for x in extra_pages if x.name == page]), None)\r\n\r\n    try:\r\n        item = page.create_item(name, enable_filter=False)\r\n        page.items[name] = item\r\n    except Exception as e:\r\n        errors.display(e, \"creating item for extra network\")\r\n        item = page.items.get(name)\r\n\r\n    page.read_user_metadata(item, use_cache=False)\r\n    item_html = page.create_item_html(tabname, item, shared.html(\"extra-networks-card.html\"))\r\n\r\n    return JSONResponse({\"html\": item_html})\r\n\r\n\r\ndef add_pages_to_demo(app):\r\n    app.add_api_route(\"/sd_extra_networks/thumb\", fetch_file, methods=[\"GET\"])\r\n    app.add_api_route(\"/sd_extra_networks/cover-images\", fetch_cover_images, methods=[\"GET\"])\r\n    app.add_api_route(\"/sd_extra_networks/metadata\", get_metadata, methods=[\"GET\"])\r\n    app.add_api_route(\"/sd_extra_networks/get-single-card\", get_single_card, methods=[\"GET\"])\r\n\r\n\r\ndef quote_js(s):\r\n    s = s.replace('\\\\', '\\\\\\\\')\r\n    s = s.replace('\"', '\\\\\"')\r\n    return f'\"{s}\"'\r\n\r\n\r\nclass ExtraNetworksPage:\r\n    def __init__(self, title):\r\n        self.title = title\r\n        
self.name = title.lower()\r\n        # This is the actual name of the extra networks tab (not txt2img/img2img).\r\n        self.extra_networks_tabname = self.name.replace(\" \", \"_\")\r\n        self.allow_prompt = True\r\n        self.allow_negative_prompt = False\r\n        self.metadata = {}\r\n        self.items = {}\r\n        self.lister = util.MassFileLister()\r\n        # HTML Templates\r\n        self.pane_tpl = shared.html(\"extra-networks-pane.html\")\r\n        self.pane_content_tree_tpl = shared.html(\"extra-networks-pane-tree.html\")\r\n        self.pane_content_dirs_tpl = shared.html(\"extra-networks-pane-dirs.html\")\r\n        self.card_tpl = shared.html(\"extra-networks-card.html\")\r\n        self.btn_tree_tpl = shared.html(\"extra-networks-tree-button.html\")\r\n        self.btn_copy_path_tpl = shared.html(\"extra-networks-copy-path-button.html\")\r\n        self.btn_metadata_tpl = shared.html(\"extra-networks-metadata-button.html\")\r\n        self.btn_edit_item_tpl = shared.html(\"extra-networks-edit-item-button.html\")\r\n\r\n    def refresh(self):\r\n        pass\r\n\r\n    def read_user_metadata(self, item, use_cache=True):\r\n        filename = item.get(\"filename\", None)\r\n        metadata = extra_networks.get_user_metadata(filename, lister=self.lister if use_cache else None)\r\n\r\n        desc = metadata.get(\"description\", None)\r\n        if desc is not None:\r\n            item[\"description\"] = desc\r\n\r\n        item[\"user_metadata\"] = metadata\r\n\r\n    def link_preview(self, filename):\r\n        quoted_filename = urllib.parse.quote(filename.replace('\\\\', '/'))\r\n        mtime, _ = self.lister.mctime(filename)\r\n        return f\"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}\"\r\n\r\n    def search_terms_from_path(self, filename, possible_directories=None):\r\n        abspath = os.path.abspath(filename)\r\n        for parentdir in (possible_directories if possible_directories is not None else 
self.allowed_directories_for_previews()):\r\n            parentdir = os.path.dirname(os.path.abspath(parentdir))\r\n            if abspath.startswith(parentdir):\r\n                return os.path.relpath(abspath, parentdir)\r\n\r\n        return \"\"\r\n\r\n    def create_item_html(\r\n        self,\r\n        tabname: str,\r\n        item: dict,\r\n        template: Optional[str] = None,\r\n    ) -> Union[str, dict]:\r\n        \"\"\"Generates HTML for a single ExtraNetworks Item.\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n            item: Dictionary containing item information.\r\n            template: Optional template string to use.\r\n\r\n        Returns:\r\n            If a template is passed: HTML string generated for this item.\r\n                Can be empty if the item is not meant to be shown.\r\n            If no template is passed: A dictionary containing the generated item's attributes.\r\n        \"\"\"\r\n        preview = item.get(\"preview\", None)\r\n        style_height = f\"height: {shared.opts.extra_networks_card_height}px;\" if shared.opts.extra_networks_card_height else ''\r\n        style_width = f\"width: {shared.opts.extra_networks_card_width}px;\" if shared.opts.extra_networks_card_width else ''\r\n        style_font_size = f\"font-size: {shared.opts.extra_networks_card_text_scale*100}%;\"\r\n        card_style = style_height + style_width + style_font_size\r\n        background_image = f'<img src=\"{html.escape(preview)}\" class=\"preview\" loading=\"lazy\">' if preview else ''\r\n\r\n        onclick = item.get(\"onclick\", None)\r\n        if onclick is None:\r\n            # Don't quote prompt/neg_prompt since they are stored as js strings already.\r\n            onclick_js_tpl = \"cardClicked('{tabname}', {prompt}, {neg_prompt}, {allow_neg});\"\r\n            onclick = onclick_js_tpl.format(\r\n                **{\r\n                    \"tabname\": tabname,\r\n                    \"prompt\": 
item[\"prompt\"],\r\n                    \"neg_prompt\": item.get(\"negative_prompt\", \"''\"),\r\n                    \"allow_neg\": str(self.allow_negative_prompt).lower(),\r\n                }\r\n            )\r\n            onclick = html.escape(onclick)\r\n\r\n        btn_copy_path = self.btn_copy_path_tpl.format(**{\"filename\": item[\"filename\"]})\r\n        btn_metadata = \"\"\r\n        metadata = item.get(\"metadata\")\r\n        if metadata:\r\n            btn_metadata = self.btn_metadata_tpl.format(\r\n                **{\r\n                    \"extra_networks_tabname\": self.extra_networks_tabname,\r\n                }\r\n            )\r\n        btn_edit_item = self.btn_edit_item_tpl.format(\r\n            **{\r\n                \"tabname\": tabname,\r\n                \"extra_networks_tabname\": self.extra_networks_tabname,\r\n            }\r\n        )\r\n\r\n        local_path = \"\"\r\n        filename = item.get(\"filename\", \"\")\r\n        for reldir in self.allowed_directories_for_previews():\r\n            absdir = os.path.abspath(reldir)\r\n\r\n            if filename.startswith(absdir):\r\n                local_path = filename[len(absdir):]\r\n\r\n        # if this is true, the item must not be shown in the default view, and must instead only be\r\n        # shown when searching for it\r\n        if shared.opts.extra_networks_hidden_models == \"Always\":\r\n            search_only = False\r\n        else:\r\n            search_only = \"/.\" in local_path or \"\\\\.\" in local_path\r\n\r\n        if search_only and shared.opts.extra_networks_hidden_models == \"Never\":\r\n            return \"\"\r\n\r\n        sort_keys = \" \".join(\r\n            [\r\n                f'data-sort-{k}=\"{html.escape(str(v))}\"'\r\n                for k, v in item.get(\"sort_keys\", {}).items()\r\n            ]\r\n        ).strip()\r\n\r\n        search_terms_html = \"\"\r\n        search_term_template = \"<span class='hidden 
{class}'>{search_term}</span>\"\r\n        for search_term in item.get(\"search_terms\", []):\r\n            search_terms_html += search_term_template.format(\r\n                **{\r\n                    \"class\": f\"search_terms{' search_only' if search_only else ''}\",\r\n                    \"search_term\": search_term,\r\n                }\r\n            )\r\n\r\n        description = (item.get(\"description\", \"\") or \"\" if shared.opts.extra_networks_card_show_desc else \"\")\r\n        if not shared.opts.extra_networks_card_description_is_html:\r\n            description = html.escape(description)\r\n\r\n        # Some items here might not be used depending on HTML template used.\r\n        args = {\r\n            \"background_image\": background_image,\r\n            \"card_clicked\": onclick,\r\n            \"copy_path_button\": btn_copy_path,\r\n            \"description\": description,\r\n            \"edit_button\": btn_edit_item,\r\n            \"local_preview\": quote_js(item[\"local_preview\"]),\r\n            \"metadata_button\": btn_metadata,\r\n            \"name\": html.escape(item[\"name\"]),\r\n            \"prompt\": item.get(\"prompt\", None),\r\n            \"save_card_preview\": html.escape(f\"return saveCardPreview(event, '{tabname}', '{item['local_preview']}');\"),\r\n            \"search_only\": \" search_only\" if search_only else \"\",\r\n            \"search_terms\": search_terms_html,\r\n            \"sort_keys\": sort_keys,\r\n            \"style\": card_style,\r\n            \"tabname\": tabname,\r\n            \"extra_networks_tabname\": self.extra_networks_tabname,\r\n        }\r\n\r\n        if template:\r\n            return template.format(**args)\r\n        else:\r\n            return args\r\n\r\n    def create_tree_dir_item_html(\r\n        self,\r\n        tabname: str,\r\n        dir_path: str,\r\n        content: Optional[str] = None,\r\n    ) -> Optional[str]:\r\n        \"\"\"Generates HTML for a directory item in 
the tree.\r\n\r\n        The generated HTML is of the format:\r\n        ```html\r\n        <li class=\"tree-list-item tree-list-item--has-subitem\">\r\n            <div class=\"tree-list-content tree-list-content-dir\"></div>\r\n            <ul class=\"tree-list tree-list--subgroup\">\r\n                {content}\r\n            </ul>\r\n        </li>\r\n        ```\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n            dir_path: Path to the directory for this item.\r\n            content: Optional HTML string that will be wrapped by this <ul>.\r\n\r\n        Returns:\r\n            HTML formatted string.\r\n        \"\"\"\r\n        if not content:\r\n            return None\r\n\r\n        btn = self.btn_tree_tpl.format(\r\n            **{\r\n                \"search_terms\": \"\",\r\n                \"subclass\": \"tree-list-content-dir\",\r\n                \"tabname\": tabname,\r\n                \"extra_networks_tabname\": self.extra_networks_tabname,\r\n                \"onclick_extra\": \"\",\r\n                \"data_path\": dir_path,\r\n                \"data_hash\": \"\",\r\n                \"action_list_item_action_leading\": \"<i class='tree-list-item-action-chevron'></i>\",\r\n                \"action_list_item_visual_leading\": \"🗀\",\r\n                \"action_list_item_label\": os.path.basename(dir_path),\r\n                \"action_list_item_visual_trailing\": \"\",\r\n                \"action_list_item_action_trailing\": \"\",\r\n            }\r\n        )\r\n        ul = f\"<ul class='tree-list tree-list--subgroup' hidden>{content}</ul>\"\r\n        return (\r\n            \"<li class='tree-list-item tree-list-item--has-subitem' data-tree-entry-type='dir'>\"\r\n            f\"{btn}{ul}\"\r\n            \"</li>\"\r\n        )\r\n\r\n    def create_tree_file_item_html(self, tabname: str, file_path: str, item: dict) -> str:\r\n        \"\"\"Generates HTML for a file item in the tree.\r\n\r\n        The generated 
HTML is of the format:\r\n        ```html\r\n        <li class=\"tree-list-item tree-list-item--subitem\">\r\n            <span data-filterable-item-text hidden></span>\r\n            <div class=\"tree-list-content tree-list-content-file\"></div>\r\n        </li>\r\n        ```\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n            file_path: The path to the file for this item.\r\n            item: Dictionary containing the item information.\r\n\r\n        Returns:\r\n            HTML formatted string.\r\n        \"\"\"\r\n        item_html_args = self.create_item_html(tabname, item)\r\n        action_buttons = \"\".join(\r\n            [\r\n                item_html_args[\"copy_path_button\"],\r\n                item_html_args[\"metadata_button\"],\r\n                item_html_args[\"edit_button\"],\r\n            ]\r\n        )\r\n        action_buttons = f\"<div class=\\\"button-row\\\">{action_buttons}</div>\"\r\n        btn = self.btn_tree_tpl.format(\r\n            **{\r\n                \"search_terms\": \"\",\r\n                \"subclass\": \"tree-list-content-file\",\r\n                \"tabname\": tabname,\r\n                \"extra_networks_tabname\": self.extra_networks_tabname,\r\n                \"onclick_extra\": item_html_args[\"card_clicked\"],\r\n                \"data_path\": file_path,\r\n                \"data_hash\": item[\"shorthash\"],\r\n                \"action_list_item_action_leading\": \"<i class='tree-list-item-action-chevron'></i>\",\r\n                \"action_list_item_visual_leading\": \"🗎\",\r\n                \"action_list_item_label\": item[\"name\"],\r\n                \"action_list_item_visual_trailing\": \"\",\r\n                \"action_list_item_action_trailing\": action_buttons,\r\n            }\r\n        )\r\n        return (\r\n            \"<li class='tree-list-item tree-list-item--subitem' data-tree-entry-type='file'>\"\r\n            f\"{btn}\"\r\n            \"</li>\"\r\n        
)\r\n\r\n    def create_tree_view_html(self, tabname: str) -> str:\r\n        \"\"\"Generates HTML for displaying folders in a tree view.\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n\r\n        Returns:\r\n            HTML string generated for this tree view.\r\n        \"\"\"\r\n        res = \"\"\r\n\r\n        # Setup the tree dictionary.\r\n        roots = self.allowed_directories_for_previews()\r\n        tree_items = {v[\"filename\"]: ExtraNetworksItem(v) for v in self.items.values()}\r\n        tree = get_tree([os.path.abspath(x) for x in roots], items=tree_items)\r\n\r\n        if not tree:\r\n            return res\r\n\r\n        def _build_tree(data: Optional[dict[str, ExtraNetworksItem]] = None) -> Optional[str]:\r\n            \"\"\"Recursively builds HTML for a tree.\r\n\r\n            Args:\r\n                data: Dictionary representing a directory tree. Can be NoneType.\r\n                    Data keys should be absolute paths from the root and values\r\n                    should be subdirectory trees or an ExtraNetworksItem.\r\n\r\n            Returns:\r\n                If data is not None: HTML string\r\n                Else: None\r\n            \"\"\"\r\n            if not data:\r\n                return None\r\n\r\n            # Lists for storing <li> items html for directories and files separately.\r\n            _dir_li = []\r\n            _file_li = []\r\n\r\n            for k, v in sorted(data.items(), key=lambda x: shared.natural_sort_key(x[0])):\r\n                if isinstance(v, (ExtraNetworksItem,)):\r\n                    _file_li.append(self.create_tree_file_item_html(tabname, k, v.item))\r\n                else:\r\n                    _dir_li.append(self.create_tree_dir_item_html(tabname, k, _build_tree(v)))\r\n\r\n            # Directories should always be displayed before files so we order them here.\r\n            return \"\".join(_dir_li) + \"\".join(_file_li)\r\n\r\n        # Add each root 
directory to the tree.\r\n        for k, v in sorted(tree.items(), key=lambda x: shared.natural_sort_key(x[0])):\r\n            item_html = self.create_tree_dir_item_html(tabname, k, _build_tree(v))\r\n            # Only add non-empty entries to the tree.\r\n            if item_html is not None:\r\n                res += item_html\r\n\r\n        return f\"<ul class='tree-list tree-list--tree'>{res}</ul>\"\r\n\r\n    def create_dirs_view_html(self, tabname: str) -> str:\r\n        \"\"\"Generates HTML for displaying folders.\"\"\"\r\n\r\n        subdirs = {}\r\n        for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:\r\n            for root, dirs, _ in sorted(os.walk(parentdir, followlinks=True), key=lambda x: shared.natural_sort_key(x[0])):\r\n                for dirname in sorted(dirs, key=shared.natural_sort_key):\r\n                    x = os.path.join(root, dirname)\r\n\r\n                    if not os.path.isdir(x):\r\n                        continue\r\n\r\n                    subdir = os.path.abspath(x)[len(parentdir):]\r\n\r\n                    if shared.opts.extra_networks_dir_button_function:\r\n                        if not subdir.startswith(os.path.sep):\r\n                            subdir = os.path.sep + subdir\r\n                    else:\r\n                        while subdir.startswith(os.path.sep):\r\n                            subdir = subdir[1:]\r\n\r\n                    is_empty = len(os.listdir(x)) == 0\r\n                    if not is_empty and not subdir.endswith(os.path.sep):\r\n                        subdir = subdir + os.path.sep\r\n\r\n                    if (os.path.sep + \".\" in subdir or subdir.startswith(\".\")) and not shared.opts.extra_networks_show_hidden_directories:\r\n                        continue\r\n\r\n                    subdirs[subdir] = 1\r\n\r\n        if subdirs:\r\n            subdirs = {\"\": 1, **subdirs}\r\n\r\n        subdirs_html = \"\".join([f\"\"\"\r\n        
<button class='lg secondary gradio-button custom-button{\" search-all\" if subdir == \"\" else \"\"}' onclick='extraNetworksSearchButton(\"{tabname}\", \"{self.extra_networks_tabname}\", event)'>\r\n        {html.escape(subdir if subdir != \"\" else \"all\")}\r\n        </button>\r\n        \"\"\" for subdir in subdirs])\r\n\r\n        return subdirs_html\r\n\r\n    def create_card_view_html(self, tabname: str, *, none_message) -> str:\r\n        \"\"\"Generates HTML for the network Card View section for a tab.\r\n\r\n        This HTML goes into the `extra-networks-pane.html` <div> with\r\n        `id='{tabname}_{extra_networks_tabname}_cards`.\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n            none_message: HTML text to show when there are no cards.\r\n\r\n        Returns:\r\n            HTML formatted string.\r\n        \"\"\"\r\n        res = []\r\n        for item in self.items.values():\r\n            res.append(self.create_item_html(tabname, item, self.card_tpl))\r\n\r\n        if not res:\r\n            dirs = \"\".join([f\"<li>{x}</li>\" for x in self.allowed_directories_for_previews()])\r\n            res = [none_message or shared.html(\"extra-networks-no-cards.html\").format(dirs=dirs)]\r\n\r\n        return \"\".join(res)\r\n\r\n    def create_html(self, tabname, *, empty=False):\r\n        \"\"\"Generates an HTML string for the current pane.\r\n\r\n        The generated HTML uses `extra-networks-pane.html` as a template.\r\n\r\n        Args:\r\n            tabname: The name of the active tab.\r\n            empty: create an empty HTML page with no items\r\n\r\n        Returns:\r\n            HTML formatted string.\r\n        \"\"\"\r\n        self.lister.reset()\r\n        self.metadata = {}\r\n\r\n        items_list = [] if empty else self.list_items()\r\n        self.items = {x[\"name\"]: x for x in items_list}\r\n\r\n        # Populate the instance metadata for each item.\r\n        for item in 
self.items.values():\r\n            metadata = item.get(\"metadata\")\r\n            if metadata:\r\n                self.metadata[item[\"name\"]] = metadata\r\n\r\n            if \"user_metadata\" not in item:\r\n                self.read_user_metadata(item)\r\n\r\n        show_tree = shared.opts.extra_networks_tree_view_default_enabled\r\n\r\n        page_params = {\r\n            \"tabname\": tabname,\r\n            \"extra_networks_tabname\": self.extra_networks_tabname,\r\n            \"data_sortdir\": shared.opts.extra_networks_card_order,\r\n            \"sort_path_active\": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Path' else '',\r\n            \"sort_name_active\": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Name' else '',\r\n            \"sort_date_created_active\": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Created' else '',\r\n            \"sort_date_modified_active\": ' extra-network-control--enabled' if shared.opts.extra_networks_card_order_field == 'Date Modified' else '',\r\n            \"tree_view_btn_extra_class\": \"extra-network-control--enabled\" if show_tree else \"\",\r\n            \"items_html\": self.create_card_view_html(tabname, none_message=\"Loading...\" if empty else None),\r\n            \"extra_networks_tree_view_default_width\": shared.opts.extra_networks_tree_view_default_width,\r\n            \"tree_view_div_default_display_class\": \"\" if show_tree else \"extra-network-dirs-hidden\",\r\n        }\r\n\r\n        if shared.opts.extra_networks_tree_view_style == \"Tree\":\r\n            pane_content = self.pane_content_tree_tpl.format(**page_params, tree_html=self.create_tree_view_html(tabname))\r\n        else:\r\n            pane_content = self.pane_content_dirs_tpl.format(**page_params, dirs_html=self.create_dirs_view_html(tabname))\r\n\r\n        return self.pane_tpl.format(**page_params, 
pane_content=pane_content)\r\n\r\n    def create_item(self, name, index=None):\r\n        raise NotImplementedError()\r\n\r\n    def list_items(self):\r\n        raise NotImplementedError()\r\n\r\n    def allowed_directories_for_previews(self):\r\n        return []\r\n\r\n    def get_sort_keys(self, path):\r\n        \"\"\"\r\n        List of default keys used for sorting in the UI.\r\n        \"\"\"\r\n        pth = Path(path)\r\n        mtime, ctime = self.lister.mctime(path)\r\n        return {\r\n            \"date_created\": int(mtime),\r\n            \"date_modified\": int(ctime),\r\n            \"name\": pth.name.lower(),\r\n            \"path\": str(pth).lower(),\r\n        }\r\n\r\n    def find_preview(self, path):\r\n        \"\"\"\r\n        Find a preview PNG for a given path (without extension) and call link_preview on it.\r\n        \"\"\"\r\n\r\n        potential_files = sum([[f\"{path}.{ext}\", f\"{path}.preview.{ext}\"] for ext in allowed_preview_extensions()], [])\r\n\r\n        for file in potential_files:\r\n            if self.lister.exists(file):\r\n                return self.link_preview(file)\r\n\r\n        return None\r\n\r\n    def find_embedded_preview(self, path, name, metadata):\r\n        \"\"\"\r\n        Find if embedded preview exists in safetensors metadata and return endpoint for it.\r\n        \"\"\"\r\n\r\n        file = f\"{path}.safetensors\"\r\n        if self.lister.exists(file) and 'ssmd_cover_images' in metadata and len(list(filter(None, json.loads(metadata['ssmd_cover_images'])))) > 0:\r\n            return f\"./sd_extra_networks/cover-images?page={self.extra_networks_tabname}&item={name}\"\r\n\r\n        return None\r\n\r\n    def find_description(self, path):\r\n        \"\"\"\r\n        Find and read a description file for a given path (without extension).\r\n        \"\"\"\r\n        for file in [f\"{path}.txt\", f\"{path}.description.txt\"]:\r\n            if not self.lister.exists(file):\r\n                
continue\r\n\r\n            try:\r\n                with open(file, \"r\", encoding=\"utf-8\", errors=\"replace\") as f:\r\n                    return f.read()\r\n            except OSError:\r\n                pass\r\n        return None\r\n\r\n    def create_user_metadata_editor(self, ui, tabname):\r\n        return ui_extra_networks_user_metadata.UserMetadataEditor(ui, tabname, self)\r\n\r\n\r\ndef initialize():\r\n    extra_pages.clear()\r\n\r\n\r\ndef register_default_pages():\r\n    from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion\r\n    from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks\r\n    from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints\r\n    register_page(ExtraNetworksPageTextualInversion())\r\n    register_page(ExtraNetworksPageHypernetworks())\r\n    register_page(ExtraNetworksPageCheckpoints())\r\n\r\n\r\nclass ExtraNetworksUi:\r\n    def __init__(self):\r\n        self.pages = None\r\n        \"\"\"gradio HTML components related to extra networks' pages\"\"\"\r\n\r\n        self.page_contents = None\r\n        \"\"\"HTML content of the above; empty initially, filled when extra pages have to be shown\"\"\"\r\n\r\n        self.stored_extra_pages = None\r\n\r\n        self.button_save_preview = None\r\n        self.preview_target_filename = None\r\n\r\n        self.tabname = None\r\n\r\n\r\ndef pages_in_preferred_order(pages):\r\n    tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(\",\")]\r\n\r\n    def tab_name_score(name):\r\n        name = name.lower()\r\n        for i, possible_match in enumerate(tab_order):\r\n            if possible_match in name:\r\n                return i\r\n\r\n        return len(pages)\r\n\r\n    tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)}\r\n\r\n    return sorted(pages, key=lambda x: 
tab_scores[x.name])\r\n\r\n\r\ndef create_ui(interface: gr.Blocks, unrelated_tabs, tabname):\r\n    ui = ExtraNetworksUi()\r\n    ui.pages = []\r\n    ui.pages_contents = []\r\n    ui.user_metadata_editors = []\r\n    ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())\r\n    ui.tabname = tabname\r\n\r\n    related_tabs = []\r\n\r\n    for page in ui.stored_extra_pages:\r\n        with gr.Tab(page.title, elem_id=f\"{tabname}_{page.extra_networks_tabname}\", elem_classes=[\"extra-page\"]) as tab:\r\n            with gr.Column(elem_id=f\"{tabname}_{page.extra_networks_tabname}_prompts\", elem_classes=[\"extra-page-prompts\"]):\r\n                pass\r\n\r\n            elem_id = f\"{tabname}_{page.extra_networks_tabname}_cards_html\"\r\n            page_elem = gr.HTML(page.create_html(tabname, empty=True), elem_id=elem_id)\r\n            ui.pages.append(page_elem)\r\n            editor = page.create_user_metadata_editor(ui, tabname)\r\n            editor.create_ui()\r\n            ui.user_metadata_editors.append(editor)\r\n            related_tabs.append(tab)\r\n\r\n    ui.button_save_preview = gr.Button('Save preview', elem_id=f\"{tabname}_save_preview\", visible=False)\r\n    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=f\"{tabname}_preview_filename\", visible=False)\r\n\r\n    for tab in unrelated_tabs:\r\n        tab.select(fn=None, _js=f\"function(){{extraNetworksUnrelatedTabSelected('{tabname}');}}\", inputs=[], outputs=[], show_progress=False)\r\n\r\n    for page, tab in zip(ui.stored_extra_pages, related_tabs):\r\n        jscode = (\r\n            \"function(){{\"\r\n            f\"extraNetworksTabSelected('{tabname}', '{tabname}_{page.extra_networks_tabname}_prompts', {str(page.allow_prompt).lower()}, {str(page.allow_negative_prompt).lower()}, '{tabname}_{page.extra_networks_tabname}');\"\r\n            f\"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');\"\r\n            \"}}\"\r\n        )\r\n  
      tab.select(fn=None, _js=jscode, inputs=[], outputs=[], show_progress=False)\r\n\r\n        def refresh():\r\n            for pg in ui.stored_extra_pages:\r\n                pg.refresh()\r\n            create_html()\r\n            return ui.pages_contents\r\n\r\n        button_refresh = gr.Button(\"Refresh\", elem_id=f\"{tabname}_{page.extra_networks_tabname}_extra_refresh_internal\", visible=False)\r\n        button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js=\"function(){ \" + f\"applyExtraNetworkFilter('{tabname}_{page.extra_networks_tabname}');\" + \" }\").then(fn=lambda: None, _js='setupAllResizeHandles')\r\n\r\n    def create_html():\r\n        ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]\r\n\r\n    def pages_html():\r\n        if not ui.pages_contents:\r\n            create_html()\r\n        return ui.pages_contents\r\n\r\n    interface.load(fn=pages_html, inputs=[], outputs=ui.pages).then(fn=lambda: None, _js='setupAllResizeHandles')\r\n\r\n    return ui\r\n\r\n\r\ndef path_is_parent(parent_path, child_path):\r\n    parent_path = os.path.abspath(parent_path)\r\n    child_path = os.path.abspath(child_path)\r\n\r\n    return child_path.startswith(parent_path)\r\n\r\n\r\ndef setup_ui(ui, gallery):\r\n    def save_preview(index, images, filename):\r\n        # this function is here for backwards compatibility and likely will be removed soon\r\n\r\n        if len(images) == 0:\r\n            print(\"There is no image in gallery to save as a preview.\")\r\n            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]\r\n\r\n        index = int(index)\r\n        index = 0 if index < 0 else index\r\n        index = len(images) - 1 if index >= len(images) else index\r\n\r\n        img_info = images[index if index >= 0 else 0]\r\n        image = image_from_url_text(img_info)\r\n        geninfo, items = read_info_from_image(image)\r\n\r\n        is_allowed = False\r\n   
     for extra_page in ui.stored_extra_pages:\r\n            if any(path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()):\r\n                is_allowed = True\r\n                break\r\n\r\n        assert is_allowed, f'writing to {filename} is not allowed'\r\n\r\n        save_image_with_geninfo(image, geninfo, filename)\r\n\r\n        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]\r\n\r\n    ui.button_save_preview.click(\r\n        fn=save_preview,\r\n        _js=\"function(x, y, z){return [selected_gallery_index(), y, z]}\",\r\n        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],\r\n        outputs=[*ui.pages]\r\n    )\r\n\r\n    for editor in ui.user_metadata_editors:\r\n        editor.setup_ui(gallery)\r\n"
  },
  {
    "path": "modules/ui_extra_networks_checkpoints.py",
    "content": "import html\r\nimport os\r\n\r\nfrom modules import shared, ui_extra_networks, sd_models\r\nfrom modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor\r\n\r\n\r\nclass ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):\r\n    def __init__(self):\r\n        super().__init__('Checkpoints')\r\n\r\n        self.allow_prompt = False\r\n\r\n    def refresh(self):\r\n        shared.refresh_checkpoints()\r\n\r\n    def create_item(self, name, index=None, enable_filter=True):\r\n        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)\r\n        if checkpoint is None:\r\n            return\r\n\r\n        path, ext = os.path.splitext(checkpoint.filename)\r\n        search_terms = [self.search_terms_from_path(checkpoint.filename)]\r\n        if checkpoint.sha256:\r\n            search_terms.append(checkpoint.sha256)\r\n        return {\r\n            \"name\": checkpoint.name_for_extra,\r\n            \"filename\": checkpoint.filename,\r\n            \"shorthash\": checkpoint.shorthash,\r\n            \"preview\": self.find_preview(path),\r\n            \"description\": self.find_description(path),\r\n            \"search_terms\": search_terms,\r\n            \"onclick\": html.escape(f\"return selectCheckpoint({ui_extra_networks.quote_js(name)})\"),\r\n            \"local_preview\": f\"{path}.{shared.opts.samples_format}\",\r\n            \"metadata\": checkpoint.metadata,\r\n            \"sort_keys\": {'default': index, **self.get_sort_keys(checkpoint.filename)},\r\n        }\r\n\r\n    def list_items(self):\r\n        # instantiate a list to protect against concurrent modification\r\n        names = list(sd_models.checkpoints_list)\r\n        for index, name in enumerate(names):\r\n            item = self.create_item(name, index)\r\n            if item is not None:\r\n                yield item\r\n\r\n    def allowed_directories_for_previews(self):\r\n        return [v for v in 
[shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]\r\n\r\n    def create_user_metadata_editor(self, ui, tabname):\r\n        return CheckpointUserMetadataEditor(ui, tabname, self)\r\n"
  },
  {
    "path": "modules/ui_extra_networks_checkpoints_user_metadata.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import ui_extra_networks_user_metadata, sd_vae, shared\r\nfrom modules.ui_common import create_refresh_button\r\n\r\n\r\nclass CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):\r\n    def __init__(self, ui, tabname, page):\r\n        super().__init__(ui, tabname, page)\r\n\r\n        self.select_vae = None\r\n\r\n    def save_user_metadata(self, name, desc, notes, vae):\r\n        user_metadata = self.get_user_metadata(name)\r\n        user_metadata[\"description\"] = desc\r\n        user_metadata[\"notes\"] = notes\r\n        user_metadata[\"vae\"] = vae\r\n\r\n        self.write_user_metadata(name, user_metadata)\r\n\r\n    def update_vae(self, name):\r\n        if name == shared.sd_model.sd_checkpoint_info.name_for_extra:\r\n            sd_vae.reload_vae_weights()\r\n\r\n    def put_values_into_components(self, name):\r\n        user_metadata = self.get_user_metadata(name)\r\n        values = super().put_values_into_components(name)\r\n\r\n        return [\r\n            *values[0:5],\r\n            user_metadata.get('vae', ''),\r\n        ]\r\n\r\n    def create_editor(self):\r\n        self.create_default_editor_elems()\r\n\r\n        with gr.Row():\r\n            self.select_vae = gr.Dropdown(choices=[\"Automatic\", \"None\"] + list(sd_vae.vae_dict), value=\"None\", label=\"Preferred VAE\", elem_id=\"checpoint_edit_user_metadata_preferred_vae\")\r\n            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {\"choices\": [\"Automatic\", \"None\"] + list(sd_vae.vae_dict)}, \"checpoint_edit_user_metadata_refresh_preferred_vae\")\r\n\r\n        self.edit_notes = gr.TextArea(label='Notes', lines=4)\r\n\r\n        self.create_default_buttons()\r\n\r\n        viewed_components = [\r\n            self.edit_name,\r\n            self.edit_description,\r\n            self.html_filedata,\r\n            self.html_preview,\r\n            self.edit_notes,\r\n     
       self.select_vae,\r\n        ]\r\n\r\n        self.button_edit\\\r\n            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\\\r\n            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])\r\n\r\n        edited_components = [\r\n            self.edit_description,\r\n            self.edit_notes,\r\n            self.select_vae,\r\n        ]\r\n\r\n        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)\r\n        self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])\r\n\r\n"
  },
  {
    "path": "modules/ui_extra_networks_hypernets.py",
    "content": "import os\r\n\r\nfrom modules import shared, ui_extra_networks\r\nfrom modules.ui_extra_networks import quote_js\r\nfrom modules.hashes import sha256_from_cache\r\n\r\n\r\nclass ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):\r\n    def __init__(self):\r\n        super().__init__('Hypernetworks')\r\n\r\n    def refresh(self):\r\n        shared.reload_hypernetworks()\r\n\r\n    def create_item(self, name, index=None, enable_filter=True):\r\n        full_path = shared.hypernetworks.get(name)\r\n        if full_path is None:\r\n            return\r\n\r\n        path, ext = os.path.splitext(full_path)\r\n        sha256 = sha256_from_cache(full_path, f'hypernet/{name}')\r\n        shorthash = sha256[0:10] if sha256 else None\r\n        search_terms = [self.search_terms_from_path(path)]\r\n        if sha256:\r\n            search_terms.append(sha256)\r\n        return {\r\n            \"name\": name,\r\n            \"filename\": full_path,\r\n            \"shorthash\": shorthash,\r\n            \"preview\": self.find_preview(path),\r\n            \"description\": self.find_description(path),\r\n            \"search_terms\": search_terms,\r\n            \"prompt\": quote_js(f\"<hypernet:{name}:\") + \" + opts.extra_networks_default_multiplier + \" + quote_js(\">\"),\r\n            \"local_preview\": f\"{path}.preview.{shared.opts.samples_format}\",\r\n            \"sort_keys\": {'default': index, **self.get_sort_keys(path + ext)},\r\n        }\r\n\r\n    def list_items(self):\r\n        # instantiate a list to protect against concurrent modification\r\n        names = list(shared.hypernetworks)\r\n        for index, name in enumerate(names):\r\n            item = self.create_item(name, index)\r\n            if item is not None:\r\n                yield item\r\n\r\n    def allowed_directories_for_previews(self):\r\n        return [shared.cmd_opts.hypernetwork_dir]\r\n\r\n"
  },
  {
    "path": "modules/ui_extra_networks_textual_inversion.py",
    "content": "import os\r\n\r\nfrom modules import ui_extra_networks, sd_hijack, shared\r\nfrom modules.ui_extra_networks import quote_js\r\n\r\n\r\nclass ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):\r\n    def __init__(self):\r\n        super().__init__('Textual Inversion')\r\n        self.allow_negative_prompt = True\r\n\r\n    def refresh(self):\r\n        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)\r\n\r\n    def create_item(self, name, index=None, enable_filter=True):\r\n        embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)\r\n        if embedding is None:\r\n            return\r\n\r\n        path, ext = os.path.splitext(embedding.filename)\r\n        search_terms = [self.search_terms_from_path(embedding.filename)]\r\n        if embedding.hash:\r\n            search_terms.append(embedding.hash)\r\n        return {\r\n            \"name\": name,\r\n            \"filename\": embedding.filename,\r\n            \"shorthash\": embedding.shorthash,\r\n            \"preview\": self.find_preview(path),\r\n            \"description\": self.find_description(path),\r\n            \"search_terms\": search_terms,\r\n            \"prompt\": quote_js(embedding.name),\r\n            \"local_preview\": f\"{path}.preview.{shared.opts.samples_format}\",\r\n            \"sort_keys\": {'default': index, **self.get_sort_keys(embedding.filename)},\r\n        }\r\n\r\n    def list_items(self):\r\n        # instantiate a list to protect against concurrent modification\r\n        names = list(sd_hijack.model_hijack.embedding_db.word_embeddings)\r\n        for index, name in enumerate(names):\r\n            item = self.create_item(name, index)\r\n            if item is not None:\r\n                yield item\r\n\r\n    def allowed_directories_for_previews(self):\r\n        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)\r\n"
  },
  {
    "path": "modules/ui_extra_networks_user_metadata.py",
    "content": "import datetime\r\nimport html\r\nimport json\r\nimport os.path\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import infotext_utils, images, sysinfo, errors, ui_extra_networks\r\n\r\n\r\nclass UserMetadataEditor:\r\n\r\n    def __init__(self, ui, tabname, page):\r\n        self.ui = ui\r\n        self.tabname = tabname\r\n        self.page = page\r\n        self.id_part = f\"{self.tabname}_{self.page.extra_networks_tabname}_edit_user_metadata\"\r\n\r\n        self.box = None\r\n\r\n        self.edit_name_input = None\r\n        self.button_edit = None\r\n\r\n        self.edit_name = None\r\n        self.edit_description = None\r\n        self.edit_notes = None\r\n        self.html_filedata = None\r\n        self.html_preview = None\r\n        self.html_status = None\r\n\r\n        self.button_cancel = None\r\n        self.button_replace_preview = None\r\n        self.button_save = None\r\n\r\n    def get_user_metadata(self, name):\r\n        item = self.page.items.get(name, {})\r\n\r\n        user_metadata = item.get('user_metadata', None)\r\n        if not user_metadata:\r\n            user_metadata = {'description': item.get('description', '')}\r\n            item['user_metadata'] = user_metadata\r\n\r\n        return user_metadata\r\n\r\n    def create_extra_default_items_in_left_column(self):\r\n        pass\r\n\r\n    def create_default_editor_elems(self):\r\n        with gr.Row():\r\n            with gr.Column(scale=2):\r\n                self.edit_name = gr.HTML(elem_classes=\"extra-network-name\")\r\n                self.edit_description = gr.Textbox(label=\"Description\", lines=4)\r\n                self.html_filedata = gr.HTML()\r\n\r\n                self.create_extra_default_items_in_left_column()\r\n\r\n            with gr.Column(scale=1, min_width=0):\r\n                self.html_preview = gr.HTML()\r\n\r\n    def create_default_buttons(self):\r\n\r\n        with gr.Row(elem_classes=\"edit-user-metadata-buttons\"):\r\n            
self.button_cancel = gr.Button('Cancel')\r\n            self.button_replace_preview = gr.Button('Replace preview', variant='primary')\r\n            self.button_save = gr.Button('Save', variant='primary')\r\n\r\n        self.html_status = gr.HTML(elem_classes=\"edit-user-metadata-status\")\r\n\r\n        self.button_cancel.click(fn=None, _js=\"closePopup\")\r\n\r\n    def get_card_html(self, name):\r\n        item = self.page.items.get(name, {})\r\n\r\n        preview_url = item.get(\"preview\", None)\r\n\r\n        if not preview_url:\r\n            filename, _ = os.path.splitext(item[\"filename\"])\r\n            preview_url = self.page.find_preview(filename)\r\n            item[\"preview\"] = preview_url\r\n\r\n        if preview_url:\r\n            preview = f'''\r\n            <div class='card standalone-card-preview'>\r\n                <img src=\"{html.escape(preview_url)}\" class=\"preview\">\r\n            </div>\r\n            '''\r\n        else:\r\n            preview = \"<div class='card standalone-card-preview'></div>\"\r\n\r\n        return preview\r\n\r\n    def relative_path(self, path):\r\n        for parent_path in self.page.allowed_directories_for_previews():\r\n            if ui_extra_networks.path_is_parent(parent_path, path):\r\n                return os.path.relpath(path, parent_path)\r\n\r\n        return os.path.basename(path)\r\n\r\n    def get_metadata_table(self, name):\r\n        item = self.page.items.get(name, {})\r\n        try:\r\n            filename = item[\"filename\"]\r\n            shorthash = item.get(\"shorthash\", None)\r\n\r\n            stats = os.stat(filename)\r\n            params = [\r\n                ('Filename: ', self.relative_path(filename)),\r\n                ('File size: ', sysinfo.pretty_bytes(stats.st_size)),\r\n                ('Hash: ', shorthash),\r\n                ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),\r\n            ]\r\n\r\n            return 
params\r\n        except Exception as e:\r\n            errors.display(e, f\"reading info for {name}\")\r\n            return []\r\n\r\n    def put_values_into_components(self, name):\r\n        user_metadata = self.get_user_metadata(name)\r\n\r\n        try:\r\n            params = self.get_metadata_table(name)\r\n        except Exception as e:\r\n            errors.display(e, f\"reading metadata info for {name}\")\r\n            params = []\r\n\r\n        table = '<table class=\"file-metadata\">' + \"\".join(f\"<tr><th>{name}</th><td>{value}</td></tr>\" for name, value in params if value is not None) + '</table>'\r\n\r\n        return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '')\r\n\r\n    def write_user_metadata(self, name, metadata):\r\n        item = self.page.items.get(name, {})\r\n        filename = item.get(\"filename\", None)\r\n        basename, ext = os.path.splitext(filename)\r\n\r\n        metadata_path = basename + '.json'\r\n        with open(metadata_path, \"w\", encoding=\"utf8\") as file:\r\n            json.dump(metadata, file, indent=4, ensure_ascii=False)\r\n        self.page.lister.update_file_entry(metadata_path)\r\n\r\n    def save_user_metadata(self, name, desc, notes):\r\n        user_metadata = self.get_user_metadata(name)\r\n        user_metadata[\"description\"] = desc\r\n        user_metadata[\"notes\"] = notes\r\n\r\n        self.write_user_metadata(name, user_metadata)\r\n\r\n    def setup_save_handler(self, button, func, components):\r\n        button\\\r\n            .click(fn=func, inputs=[self.edit_name_input, *components], outputs=[])\\\r\n            .then(fn=None, _js=\"function(name){closePopup(); extraNetworksRefreshSingleCard(\" + json.dumps(self.page.name) + \",\" + json.dumps(self.tabname) + \", name);}\", inputs=[self.edit_name_input], outputs=[])\r\n\r\n    def create_editor(self):\r\n        self.create_default_editor_elems()\r\n\r\n        
self.edit_notes = gr.TextArea(label='Notes', lines=4)\r\n\r\n        self.create_default_buttons()\r\n\r\n        self.button_edit\\\r\n            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=[self.edit_name, self.edit_description, self.html_filedata, self.html_preview, self.edit_notes])\\\r\n            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])\r\n\r\n        self.setup_save_handler(self.button_save, self.save_user_metadata, [self.edit_description, self.edit_notes])\r\n\r\n    def create_ui(self):\r\n        with gr.Box(visible=False, elem_id=self.id_part, elem_classes=\"edit-user-metadata\") as box:\r\n            self.box = box\r\n\r\n            self.edit_name_input = gr.Textbox(\"Edit user metadata card id\", visible=False, elem_id=f\"{self.id_part}_name\")\r\n            self.button_edit = gr.Button(\"Edit user metadata\", visible=False, elem_id=f\"{self.id_part}_button\")\r\n\r\n            self.create_editor()\r\n\r\n    def save_preview(self, index, gallery, name):\r\n        if len(gallery) == 0:\r\n            return self.get_card_html(name), \"There is no image in gallery to save as a preview.\"\r\n\r\n        item = self.page.items.get(name, {})\r\n\r\n        index = int(index)\r\n        index = 0 if index < 0 else index\r\n        index = len(gallery) - 1 if index >= len(gallery) else index\r\n\r\n        img_info = gallery[index if index >= 0 else 0]\r\n        image = infotext_utils.image_from_url_text(img_info)\r\n        geninfo, items = images.read_info_from_image(image)\r\n\r\n        images.save_image_with_geninfo(image, geninfo, item[\"local_preview\"])\r\n        self.page.lister.update_file_entry(item[\"local_preview\"])\r\n        item['preview'] = self.page.find_preview(item[\"local_preview\"])\r\n        return self.get_card_html(name), ''\r\n\r\n    def setup_ui(self, gallery):\r\n        self.button_replace_preview.click(\r\n            fn=self.save_preview,\r\n      
      _js=f\"function(x, y, z){{return [selected_gallery_index_id('{self.tabname + '_gallery_container'}'), y, z]}}\",\r\n            inputs=[self.edit_name_input, gallery, self.edit_name_input],\r\n            outputs=[self.html_preview, self.html_status]\r\n        ).then(\r\n            fn=None,\r\n            _js=\"function(name){extraNetworksRefreshSingleCard(\" + json.dumps(self.page.name) + \",\" + json.dumps(self.tabname) + \", name);}\",\r\n            inputs=[self.edit_name_input],\r\n            outputs=[]\r\n        )\r\n"
  },
  {
    "path": "modules/ui_gradio_extensions.py",
    "content": "import os\r\nimport gradio as gr\r\n\r\nfrom modules import localization, shared, scripts, util\r\nfrom modules.paths import script_path, data_path\r\n\r\n\r\ndef webpath(fn):\r\n    return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}'\r\n\r\n\r\ndef javascript_html():\r\n    # Ensure localization is in `window` before scripts\r\n    head = f'<script type=\"text/javascript\">{localization.localization_js(shared.opts.localization)}</script>\\n'\r\n\r\n    script_js = os.path.join(script_path, \"script.js\")\r\n    head += f'<script type=\"text/javascript\" src=\"{webpath(script_js)}\"></script>\\n'\r\n\r\n    for script in scripts.list_scripts(\"javascript\", \".js\"):\r\n        head += f'<script type=\"text/javascript\" src=\"{webpath(script.path)}\"></script>\\n'\r\n\r\n    for script in scripts.list_scripts(\"javascript\", \".mjs\"):\r\n        head += f'<script type=\"module\" src=\"{webpath(script.path)}\"></script>\\n'\r\n\r\n    if shared.cmd_opts.theme:\r\n        head += f'<script type=\"text/javascript\">set_theme(\\\"{shared.cmd_opts.theme}\\\");</script>\\n'\r\n\r\n    return head\r\n\r\n\r\ndef css_html():\r\n    head = \"\"\r\n\r\n    def stylesheet(fn):\r\n        return f'<link rel=\"stylesheet\" property=\"stylesheet\" href=\"{webpath(fn)}\">'\r\n\r\n    for cssfile in scripts.list_files_with_name(\"style.css\"):\r\n        head += stylesheet(cssfile)\r\n\r\n    user_css = os.path.join(data_path, \"user.css\")\r\n    if os.path.exists(user_css):\r\n        head += stylesheet(user_css)\r\n\r\n    from modules.shared_gradio_themes import resolve_var\r\n    light = resolve_var('background_fill_primary')\r\n    dark = resolve_var('background_fill_primary_dark')\r\n    head += f'<style>html {{ background-color: {light}; }} @media (prefers-color-scheme: dark) {{ html {{background-color:  {dark}; }} }}</style>'\r\n\r\n    return head\r\n\r\n\r\ndef reload_javascript():\r\n    js = javascript_html()\r\n    css = css_html()\r\n\r\n 
   def template_response(*args, **kwargs):\r\n        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)\r\n        res.body = res.body.replace(b'</head>', f'{js}<meta name=\"referrer\" content=\"no-referrer\"/></head>'.encode(\"utf8\"))\r\n        res.body = res.body.replace(b'</body>', f'{css}</body>'.encode(\"utf8\"))\r\n        res.init_headers()\r\n        return res\r\n\r\n    gr.routes.templates.TemplateResponse = template_response\r\n\r\n\r\nif not hasattr(shared, 'GradioTemplateResponseOriginal'):\r\n    shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse\r\n"
  },
  {
    "path": "modules/ui_loadsave.py",
    "content": "import json\r\nimport os\r\n\r\nimport gradio as gr\r\n\r\nfrom modules import errors\r\nfrom modules.ui_components import ToolButton, InputAccordion\r\n\r\n\r\ndef radio_choices(comp):  # gradio 3.41 changes choices from list of values to list of pairs\r\n    return [x[0] if isinstance(x, tuple) else x for x in getattr(comp, 'choices', [])]\r\n\r\n\r\nclass UiLoadsave:\r\n    \"\"\"allows saving and restoring default values for gradio components\"\"\"\r\n\r\n    def __init__(self, filename):\r\n        self.filename = filename\r\n        self.ui_settings = {}\r\n        self.component_mapping = {}\r\n        self.error_loading = False\r\n        self.finalized_ui = False\r\n\r\n        self.ui_defaults_view = None\r\n        self.ui_defaults_apply = None\r\n        self.ui_defaults_review = None\r\n\r\n        try:\r\n            self.ui_settings = self.read_from_file()\r\n        except FileNotFoundError:\r\n            pass\r\n        except Exception as e:\r\n            self.error_loading = True\r\n            errors.display(e, \"loading settings\")\r\n\r\n    def add_component(self, path, x):\r\n        \"\"\"adds component to the registry of tracked components\"\"\"\r\n\r\n        assert not self.finalized_ui\r\n\r\n        def apply_field(obj, field, condition=None, init_field=None):\r\n            key = f\"{path}/{field}\"\r\n\r\n            if getattr(obj, 'custom_script_source', None) is not None:\r\n                key = f\"customscript/{obj.custom_script_source}/{key}\"\r\n\r\n            if getattr(obj, 'do_not_save_to_config', False):\r\n                return\r\n\r\n            saved_value = self.ui_settings.get(key, None)\r\n\r\n            if isinstance(obj, gr.Accordion) and isinstance(x, InputAccordion) and field == 'value':\r\n                field = 'open'\r\n\r\n            if saved_value is None:\r\n                self.ui_settings[key] = getattr(obj, field)\r\n            elif condition and not condition(saved_value):\r\n    
            pass\r\n            else:\r\n                if isinstance(obj, gr.Textbox) and field == 'value':  # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies\r\n                    saved_value = str(saved_value)\r\n                elif isinstance(obj, gr.Number) and field == 'value':\r\n                    try:\r\n                        saved_value = float(saved_value)\r\n                    except ValueError:\r\n                        return\r\n\r\n                setattr(obj, field, saved_value)\r\n                if init_field is not None:\r\n                    init_field(saved_value)\r\n\r\n            if field == 'value' and key not in self.component_mapping:\r\n                self.component_mapping[key] = obj\r\n\r\n        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton, gr.Button] and x.visible:\r\n            apply_field(x, 'visible')\r\n\r\n        if type(x) == gr.Slider:\r\n            apply_field(x, 'value')\r\n            apply_field(x, 'minimum')\r\n            apply_field(x, 'maximum')\r\n            apply_field(x, 'step')\r\n\r\n        if type(x) == gr.Radio:\r\n            apply_field(x, 'value', lambda val: val in radio_choices(x))\r\n\r\n        if type(x) == gr.Checkbox:\r\n            apply_field(x, 'value')\r\n\r\n        if type(x) == gr.Textbox:\r\n            apply_field(x, 'value')\r\n\r\n        if type(x) == gr.Number:\r\n            apply_field(x, 'value')\r\n\r\n        if type(x) == gr.Dropdown:\r\n            def check_dropdown(val):\r\n                choices = radio_choices(x)\r\n                if getattr(x, 'multiselect', False):\r\n                    return all(value in choices for value in val)\r\n                else:\r\n                    return val in choices\r\n\r\n            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))\r\n\r\n        if type(x) == InputAccordion:\r\n            if 
hasattr(x, 'custom_script_source'):\r\n                x.accordion.custom_script_source = x.custom_script_source\r\n            if x.accordion.visible:\r\n                apply_field(x.accordion, 'visible')\r\n            apply_field(x, 'value')\r\n            apply_field(x.accordion, 'value')\r\n\r\n        def check_tab_id(tab_id):\r\n            tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))\r\n            if type(tab_id) == str:\r\n                tab_ids = [t.id for t in tab_items]\r\n                return tab_id in tab_ids\r\n            elif type(tab_id) == int:\r\n                return 0 <= tab_id < len(tab_items)\r\n            else:\r\n                return False\r\n\r\n        if type(x) == gr.Tabs:\r\n            apply_field(x, 'selected', check_tab_id)\r\n\r\n    def add_block(self, x, path=\"\"):\r\n        \"\"\"adds all components inside a gradio block x to the registry of tracked components\"\"\"\r\n\r\n        if hasattr(x, 'children'):\r\n            if isinstance(x, gr.Tabs) and x.elem_id is not None:\r\n                # Tabs element can't have a label, have to use elem_id instead\r\n                self.add_component(f\"{path}/Tabs@{x.elem_id}\", x)\r\n            for c in x.children:\r\n                self.add_block(c, path)\r\n        elif x.label is not None:\r\n            self.add_component(f\"{path}/{x.label}\", x)\r\n        elif isinstance(x, gr.Button) and x.value is not None:\r\n            self.add_component(f\"{path}/{x.value}\", x)\r\n\r\n    def read_from_file(self):\r\n        with open(self.filename, \"r\", encoding=\"utf8\") as file:\r\n            return json.load(file)\r\n\r\n    def write_to_file(self, current_ui_settings):\r\n        with open(self.filename, \"w\", encoding=\"utf8\") as file:\r\n            json.dump(current_ui_settings, file, indent=4, ensure_ascii=False)\r\n\r\n    def dump_defaults(self):\r\n        \"\"\"saves default values to a file unless the file is present and there 
was an error loading default values at start\"\"\"\r\n\r\n        if self.error_loading and os.path.exists(self.filename):\r\n            return\r\n\r\n        self.write_to_file(self.ui_settings)\r\n\r\n    def iter_changes(self, current_ui_settings, values):\r\n        \"\"\"\r\n        given a dictionary with defaults from a file and current values from gradio elements, returns\r\n        an iterator over tuples of values that are not the same between the file and the current;\r\n        tuple contents are: path, old value, new value\r\n        \"\"\"\r\n\r\n        for (path, component), new_value in zip(self.component_mapping.items(), values):\r\n            old_value = current_ui_settings.get(path)\r\n\r\n            choices = radio_choices(component)\r\n            if isinstance(new_value, int) and choices:\r\n                if new_value >= len(choices):\r\n                    continue\r\n\r\n                new_value = choices[new_value]\r\n                if isinstance(new_value, tuple):\r\n                    new_value = new_value[0]\r\n\r\n            if new_value == old_value:\r\n                continue\r\n\r\n            if old_value is None and new_value == '' or new_value == []:\r\n                continue\r\n\r\n            yield path, old_value, new_value\r\n\r\n    def ui_view(self, *values):\r\n        text = [\"<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>\"]\r\n\r\n        for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):\r\n            if old_value is None:\r\n                old_value = \"<span class='ui-defaults-none'>None</span>\"\r\n\r\n            text.append(f\"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>\")\r\n\r\n        if len(text) == 1:\r\n            text.append(\"<tr><td colspan=3>No changes</td></tr>\")\r\n\r\n        text.append(\"</tbody>\")\r\n        return \"\".join(text)\r\n\r\n    def ui_apply(self, *values):\r\n        num_changed 
= 0\r\n\r\n        current_ui_settings = self.read_from_file()\r\n\r\n        for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):\r\n            num_changed += 1\r\n            current_ui_settings[path] = new_value\r\n\r\n        if num_changed == 0:\r\n            return \"No changes.\"\r\n\r\n        self.write_to_file(current_ui_settings)\r\n\r\n        return f\"Wrote {num_changed} changes.\"\r\n\r\n    def create_ui(self):\r\n        \"\"\"creates ui elements for editing defaults UI, without adding any logic to them\"\"\"\r\n\r\n        gr.HTML(\r\n            f\"This page allows you to change default values in UI elements on other tabs.<br />\"\r\n            f\"Make your changes, press 'View changes' to review the changed default values,<br />\"\r\n            f\"then press 'Apply' to write them to {self.filename}.<br />\"\r\n            f\"New defaults will apply after you restart the UI.<br />\"\r\n        )\r\n\r\n        with gr.Row():\r\n            self.ui_defaults_view = gr.Button(value='View changes', elem_id=\"ui_defaults_view\", variant=\"secondary\")\r\n            self.ui_defaults_apply = gr.Button(value='Apply', elem_id=\"ui_defaults_apply\", variant=\"primary\")\r\n\r\n        self.ui_defaults_review = gr.HTML(\"\")\r\n\r\n    def setup_ui(self):\r\n        \"\"\"adds logic to elements created with create_ui; all add_block class must be made before this\"\"\"\r\n\r\n        assert not self.finalized_ui\r\n        self.finalized_ui = True\r\n\r\n        self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])\r\n        self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])\r\n"
  },
  {
    "path": "modules/ui_postprocessing.py",
    "content": "import gradio as gr\r\nfrom modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow\r\nimport modules.infotext_utils as parameters_copypaste\r\nfrom modules.ui_components import ResizeHandleRow\r\n\r\n\r\ndef create_ui():\r\n    dummy_component = gr.Label(visible=False)\r\n    tab_index = gr.Number(value=0, visible=False)\r\n\r\n    with ResizeHandleRow(equal_height=False, variant='compact'):\r\n        with gr.Column(variant='compact'):\r\n            with gr.Tabs(elem_id=\"mode_extras\"):\r\n                with gr.TabItem('Single Image', id=\"single_image\", elem_id=\"extras_single_tab\") as tab_single:\r\n                    extras_image = gr.Image(label=\"Source\", source=\"upload\", interactive=True, type=\"pil\", elem_id=\"extras_image\", image_mode=\"RGBA\")\r\n\r\n                with gr.TabItem('Batch Process', id=\"batch_process\", elem_id=\"extras_batch_process_tab\") as tab_batch:\r\n                    image_batch = gr.Files(label=\"Batch Process\", interactive=True, elem_id=\"extras_image_batch\")\r\n\r\n                with gr.TabItem('Batch from Directory', id=\"batch_from_directory\", elem_id=\"extras_batch_directory_tab\") as tab_batch_dir:\r\n                    extras_batch_input_dir = gr.Textbox(label=\"Input directory\", **shared.hide_dirs, placeholder=\"A directory on the same machine where the server is running.\", elem_id=\"extras_batch_input_dir\")\r\n                    extras_batch_output_dir = gr.Textbox(label=\"Output directory\", **shared.hide_dirs, placeholder=\"Leave blank to save images to the default path.\", elem_id=\"extras_batch_output_dir\")\r\n                    show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id=\"extras_show_extras_results\")\r\n\r\n            script_inputs = scripts.scripts_postproc.setup_ui()\r\n\r\n        with gr.Column():\r\n            toprow = ui_toprow.Toprow(is_compact=True, is_img2img=False, id_part=\"extras\")\r\n           
 toprow.create_inline_toprow_image()\r\n            submit = toprow.submit\r\n\r\n            output_panel = ui_common.create_output_panel(\"extras\", shared.opts.outdir_extras_samples)\r\n\r\n    tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])\r\n    tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])\r\n    tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])\r\n\r\n    submit.click(\r\n        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing_webui, extra_outputs=[None, '']),\r\n        _js=\"submit_extras\",\r\n        inputs=[\r\n            dummy_component,\r\n            tab_index,\r\n            extras_image,\r\n            image_batch,\r\n            extras_batch_input_dir,\r\n            extras_batch_output_dir,\r\n            show_extras_results,\r\n            *script_inputs\r\n        ],\r\n        outputs=[\r\n            output_panel.gallery,\r\n            output_panel.generation_info,\r\n            output_panel.html_log,\r\n        ],\r\n        show_progress=False,\r\n    )\r\n\r\n    parameters_copypaste.add_paste_fields(\"extras\", extras_image, None)\r\n\r\n    extras_image.change(\r\n        fn=scripts.scripts_postproc.image_changed,\r\n        inputs=[], outputs=[]\r\n    )\r\n"
  },
  {
    "path": "modules/ui_prompt_styles.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import shared, ui_common, ui_components, styles\r\n\r\nstyles_edit_symbol = '\\U0001f58c\\uFE0F'  # 🖌️\r\nstyles_materialize_symbol = '\\U0001f4cb'  # 📋\r\nstyles_copy_symbol = '\\U0001f4dd'  # 📝\r\n\r\n\r\ndef select_style(name):\r\n    style = shared.prompt_styles.styles.get(name)\r\n    existing = style is not None\r\n    empty = not name\r\n\r\n    prompt = style.prompt if style else gr.update()\r\n    negative_prompt = style.negative_prompt if style else gr.update()\r\n\r\n    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)\r\n\r\n\r\ndef save_style(name, prompt, negative_prompt):\r\n    if not name:\r\n        return gr.update(visible=False)\r\n\r\n    existing_style = shared.prompt_styles.styles.get(name)\r\n    path = existing_style.path if existing_style is not None else None\r\n\r\n    style = styles.PromptStyle(name, prompt, negative_prompt, path)\r\n    shared.prompt_styles.styles[style.name] = style\r\n    shared.prompt_styles.save_styles()\r\n\r\n    return gr.update(visible=True)\r\n\r\n\r\ndef delete_style(name):\r\n    if name == \"\":\r\n        return\r\n\r\n    shared.prompt_styles.styles.pop(name, None)\r\n    shared.prompt_styles.save_styles()\r\n\r\n    return '', '', ''\r\n\r\n\r\ndef materialize_styles(prompt, negative_prompt, styles):\r\n    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)\r\n    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)\r\n\r\n    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]\r\n\r\n\r\ndef refresh_styles():\r\n    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))\r\n\r\n\r\nclass UiPromptStyles:\r\n    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):\r\n        self.tabname = tabname\r\n        
self.main_ui_prompt = main_ui_prompt\r\n        self.main_ui_negative_prompt = main_ui_negative_prompt\r\n\r\n        with gr.Row(elem_id=f\"{tabname}_styles_row\"):\r\n            self.dropdown = gr.Dropdown(label=\"Styles\", show_label=False, elem_id=f\"{tabname}_styles\", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip=\"Styles\")\r\n            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f\"{tabname}_styles_edit_button\", tooltip=\"Edit styles\")\r\n\r\n        with gr.Box(elem_id=f\"{tabname}_styles_dialog\", elem_classes=\"popup-dialog\") as styles_dialog:\r\n            with gr.Row():\r\n                self.selection = gr.Dropdown(label=\"Styles\", elem_id=f\"{tabname}_styles_edit_select\", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info=\"Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. 
Otherwise, style's text will be added to the end of the prompt.\")\r\n                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {\"choices\": list(shared.prompt_styles.styles)}, f\"refresh_{tabname}_styles\")\r\n                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f\"{tabname}_style_apply_dialog\", tooltip=\"Apply all selected styles from the style selection dropdown in main UI to the prompt.\")\r\n                self.copy = ui_components.ToolButton(value=styles_copy_symbol, elem_id=f\"{tabname}_style_copy\", tooltip=\"Copy main UI prompt to style.\")\r\n\r\n            with gr.Row():\r\n                self.prompt = gr.Textbox(label=\"Prompt\", show_label=True, elem_id=f\"{tabname}_edit_style_prompt\", lines=3, elem_classes=[\"prompt\"])\r\n\r\n            with gr.Row():\r\n                self.neg_prompt = gr.Textbox(label=\"Negative prompt\", show_label=True, elem_id=f\"{tabname}_edit_style_neg_prompt\", lines=3, elem_classes=[\"prompt\"])\r\n\r\n            with gr.Row():\r\n                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)\r\n                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)\r\n                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')\r\n\r\n        self.selection.change(\r\n            fn=select_style,\r\n            inputs=[self.selection],\r\n            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],\r\n            show_progress=False,\r\n        )\r\n\r\n        self.save.click(\r\n            fn=save_style,\r\n            inputs=[self.selection, self.prompt, self.neg_prompt],\r\n            outputs=[self.delete],\r\n            show_progress=False,\r\n        ).then(refresh_styles, outputs=[self.dropdown, self.selection], 
show_progress=False)\r\n\r\n        self.delete.click(\r\n            fn=delete_style,\r\n            _js='function(name){ if(name == \"\") return \"\"; return confirm(\"Delete style \" + name + \"?\") ? name : \"\"; }',\r\n            inputs=[self.selection],\r\n            outputs=[self.selection, self.prompt, self.neg_prompt],\r\n            show_progress=False,\r\n        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)\r\n\r\n        self.setup_apply_button(self.materialize)\r\n\r\n        self.copy.click(\r\n            fn=lambda p, n: (p, n),\r\n            inputs=[main_ui_prompt, main_ui_negative_prompt],\r\n            outputs=[self.prompt, self.neg_prompt],\r\n            show_progress=False,\r\n        )\r\n\r\n        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)\r\n\r\n    def setup_apply_button(self, button):\r\n        button.click(\r\n            fn=materialize_styles,\r\n            inputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],\r\n            outputs=[self.main_ui_prompt, self.main_ui_negative_prompt, self.dropdown],\r\n            show_progress=False,\r\n        ).then(fn=None, _js=\"function(){update_\"+self.tabname+\"_tokens(); closePopup();}\", show_progress=False)\r\n"
  },
  {
    "path": "modules/ui_settings.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer, shared_items\r\nfrom modules.call_queue import wrap_gradio_call_no_job\r\nfrom modules.options import options_section\r\nfrom modules.shared import opts\r\nfrom modules.ui_components import FormRow\r\nfrom modules.ui_gradio_extensions import reload_javascript\r\nfrom concurrent.futures import ThreadPoolExecutor, as_completed\r\n\r\n\r\ndef get_value_for_setting(key):\r\n    value = getattr(opts, key)\r\n\r\n    info = opts.data_labels[key]\r\n    args = info.component_args() if callable(info.component_args) else info.component_args or {}\r\n    args = {k: v for k, v in args.items() if k not in {'precision'}}\r\n\r\n    return gr.update(value=value, **args)\r\n\r\n\r\ndef create_setting_component(key, is_quicksettings=False):\r\n    def fun():\r\n        return opts.data[key] if key in opts.data else opts.data_labels[key].default\r\n\r\n    info = opts.data_labels[key]\r\n    t = type(info.default)\r\n\r\n    args = info.component_args() if callable(info.component_args) else info.component_args\r\n\r\n    if info.component is not None:\r\n        comp = info.component\r\n    elif t == str:\r\n        comp = gr.Textbox\r\n    elif t == int:\r\n        comp = gr.Number\r\n    elif t == bool:\r\n        comp = gr.Checkbox\r\n    else:\r\n        raise Exception(f'bad options item type: {t} for key {key}')\r\n\r\n    elem_id = f\"setting_{key}\"\r\n\r\n    if info.refresh is not None:\r\n        if is_quicksettings:\r\n            res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))\r\n            ui_common.create_refresh_button(res, info.refresh, info.component_args, f\"refresh_{key}\")\r\n        else:\r\n            with FormRow():\r\n                res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))\r\n                ui_common.create_refresh_button(res, info.refresh, 
info.component_args, f\"refresh_{key}\")\r\n    else:\r\n        res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {}))\r\n\r\n    return res\r\n\r\n\r\nclass UiSettings:\r\n    submit = None\r\n    result = None\r\n    interface = None\r\n    components = None\r\n    component_dict = None\r\n    dummy_component = None\r\n    quicksettings_list = None\r\n    quicksettings_names = None\r\n    text_settings = None\r\n    show_all_pages = None\r\n    show_one_page = None\r\n    search_input = None\r\n\r\n    def run_settings(self, *args):\r\n        changed = []\r\n\r\n        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):\r\n            assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f\"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}\"\r\n\r\n        for key, value, comp in zip(opts.data_labels.keys(), args, self.components):\r\n            if comp == self.dummy_component:\r\n                continue\r\n\r\n            if opts.set(key, value):\r\n                changed.append(key)\r\n\r\n        try:\r\n            opts.save(shared.config_filename)\r\n        except RuntimeError:\r\n            return opts.dumpjson(), f'{len(changed)} settings changed without save: {\", \".join(changed)}.'\r\n        return opts.dumpjson(), f'{len(changed)} settings changed{\": \" if changed else \"\"}{\", \".join(changed)}.'\r\n\r\n    def run_settings_single(self, value, key):\r\n        if not opts.same_type(value, opts.data_labels[key].default):\r\n            return gr.update(visible=True), opts.dumpjson()\r\n\r\n        if value is None or not opts.set(key, value):\r\n            return gr.update(value=getattr(opts, key)), opts.dumpjson()\r\n\r\n        opts.save(shared.config_filename)\r\n\r\n        return get_value_for_setting(key), opts.dumpjson()\r\n\r\n    def register_settings(self):\r\n        
script_callbacks.ui_settings_callback()\r\n\r\n    def create_ui(self, loadsave, dummy_component):\r\n        self.components = []\r\n        self.component_dict = {}\r\n        self.dummy_component = dummy_component\r\n\r\n        shared.settings_components = self.component_dict\r\n\r\n        # we add this as late as possible so that scripts have already registered their callbacks\r\n        opts.data_labels.update(options_section(('callbacks', \"Callbacks\", \"system\"), {\r\n            **shared_items.callbacks_order_settings(),\r\n        }))\r\n\r\n        opts.reorder()\r\n\r\n        with gr.Blocks(analytics_enabled=False) as settings_interface:\r\n            with gr.Row():\r\n                with gr.Column(scale=6):\r\n                    self.submit = gr.Button(value=\"Apply settings\", variant='primary', elem_id=\"settings_submit\")\r\n                with gr.Column():\r\n                    restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id=\"settings_restart_gradio\")\r\n\r\n            self.result = gr.HTML(elem_id=\"settings_result\")\r\n\r\n            self.quicksettings_names = opts.quicksettings_list\r\n            self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'}\r\n\r\n            self.quicksettings_list = []\r\n\r\n            previous_section = None\r\n            current_tab = None\r\n            current_row = None\r\n            with gr.Tabs(elem_id=\"settings\"):\r\n                for i, (k, item) in enumerate(opts.data_labels.items()):\r\n                    section_must_be_skipped = item.section[0] is None\r\n\r\n                    if previous_section != item.section and not section_must_be_skipped:\r\n                        elem_id, text = item.section\r\n\r\n                        if current_tab is not None:\r\n                            current_row.__exit__()\r\n                            current_tab.__exit__()\r\n\r\n                        
gr.Group()\r\n                        current_tab = gr.TabItem(elem_id=f\"settings_{elem_id}\", label=text)\r\n                        current_tab.__enter__()\r\n                        current_row = gr.Column(elem_id=f\"column_settings_{elem_id}\", variant='compact')\r\n                        current_row.__enter__()\r\n\r\n                        previous_section = item.section\r\n\r\n                    if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings:\r\n                        self.quicksettings_list.append((i, k, item))\r\n                        self.components.append(dummy_component)\r\n                    elif section_must_be_skipped:\r\n                        self.components.append(dummy_component)\r\n                    else:\r\n                        component = create_setting_component(k)\r\n                        self.component_dict[k] = component\r\n                        self.components.append(component)\r\n\r\n                if current_tab is not None:\r\n                    current_row.__exit__()\r\n                    current_tab.__exit__()\r\n\r\n                with gr.TabItem(\"Defaults\", id=\"defaults\", elem_id=\"settings_tab_defaults\"):\r\n                    loadsave.create_ui()\r\n\r\n                with gr.TabItem(\"Sysinfo\", id=\"sysinfo\", elem_id=\"settings_tab_sysinfo\"):\r\n                    gr.HTML('<a href=\"./internal/sysinfo-download\" class=\"sysinfo_big_link\" download>Download system info</a><br /><a href=\"./internal/sysinfo\" target=\"_blank\">(or open as text in a new page)</a>', elem_id=\"sysinfo_download\")\r\n\r\n                    with gr.Row():\r\n                        with gr.Column(scale=1):\r\n                            sysinfo_check_file = gr.File(label=\"Check system info for validity\", type='binary')\r\n                        with gr.Column(scale=1):\r\n                            sysinfo_check_output = gr.HTML(\"\", elem_id=\"sysinfo_validity\")\r\n                       
 with gr.Column(scale=100):\r\n                            pass\r\n\r\n                with gr.TabItem(\"Actions\", id=\"actions\", elem_id=\"settings_tab_actions\"):\r\n                    request_notifications = gr.Button(value='Request browser notifications', elem_id=\"request_notifications\")\r\n                    download_localization = gr.Button(value='Download localization template', elem_id=\"download_localization\")\r\n                    reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id=\"settings_reload_script_bodies\")\r\n                    with gr.Row():\r\n                        unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id=\"sett_unload_sd_model\")\r\n                        reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id=\"sett_reload_sd_model\")\r\n                    with gr.Row():\r\n                        calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id=\"calculate_all_checkpoint_hash\")\r\n                        calculate_all_checkpoint_hash_threads = gr.Number(value=1, label=\"Number of parallel calculations\", elem_id=\"calculate_all_checkpoint_hash_threads\", precision=0, minimum=1)\r\n\r\n                with gr.TabItem(\"Licenses\", id=\"licenses\", elem_id=\"settings_tab_licenses\"):\r\n                    gr.HTML(shared.html(\"licenses.html\"), elem_id=\"licenses\")\r\n\r\n                self.show_all_pages = gr.Button(value=\"Show all pages\", elem_id=\"settings_show_all_pages\")\r\n                self.show_one_page = gr.Button(value=\"Show only one page\", elem_id=\"settings_show_one_page\", visible=False)\r\n                self.show_one_page.click(lambda: None)\r\n\r\n                self.search_input = gr.Textbox(value=\"\", elem_id=\"settings_search\", max_lines=1, placeholder=\"Search...\", show_label=False)\r\n\r\n                
self.text_settings = gr.Textbox(elem_id=\"settings_json\", value=lambda: opts.dumpjson(), visible=False)\r\n\r\n            def call_func_and_return_text(func, text):\r\n                def handler():\r\n                    t = timer.Timer()\r\n                    func()\r\n                    t.record(text)\r\n\r\n                    return f'{text} in {t.total:.1f}s'\r\n\r\n                return handler\r\n\r\n            unload_sd_model.click(\r\n                fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),\r\n                inputs=[],\r\n                outputs=[self.result]\r\n            )\r\n\r\n            reload_sd_model.click(\r\n                fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),\r\n                inputs=[],\r\n                outputs=[self.result]\r\n            )\r\n\r\n            request_notifications.click(\r\n                fn=lambda: None,\r\n                inputs=[],\r\n                outputs=[],\r\n                _js='function(){}'\r\n            )\r\n\r\n            download_localization.click(\r\n                fn=lambda: None,\r\n                inputs=[],\r\n                outputs=[],\r\n                _js='download_localization'\r\n            )\r\n\r\n            def reload_scripts():\r\n                scripts.reload_script_body_only()\r\n                reload_javascript()  # need to refresh the html page\r\n\r\n            reload_script_bodies.click(\r\n                fn=reload_scripts,\r\n                inputs=[],\r\n                outputs=[]\r\n            )\r\n\r\n            restart_gradio.click(\r\n                fn=shared.state.request_restart,\r\n                _js='restart_reload',\r\n                inputs=[],\r\n                outputs=[],\r\n            )\r\n\r\n            def check_file(x):\r\n                if x is None:\r\n                    return ''\r\n\r\n                if 
sysinfo.check(x.decode('utf8', errors='ignore')):\r\n                    return 'Valid'\r\n\r\n                return 'Invalid'\r\n\r\n            sysinfo_check_file.change(\r\n                fn=check_file,\r\n                inputs=[sysinfo_check_file],\r\n                outputs=[sysinfo_check_output],\r\n            )\r\n\r\n            def calculate_all_checkpoint_hash_fn(max_thread):\r\n                checkpoints_list = sd_models.checkpoints_list.values()\r\n                with ThreadPoolExecutor(max_workers=max_thread) as executor:\r\n                    futures = [executor.submit(checkpoint.calculate_shorthash) for checkpoint in checkpoints_list]\r\n                    completed = 0\r\n                    for _ in as_completed(futures):\r\n                        completed += 1\r\n                        print(f\"{completed} / {len(checkpoints_list)} \")\r\n                    print(\"Finish calculating hash for all checkpoints\")\r\n\r\n            calculate_all_checkpoint_hash.click(\r\n                fn=calculate_all_checkpoint_hash_fn,\r\n                inputs=[calculate_all_checkpoint_hash_threads],\r\n            )\r\n\r\n        self.interface = settings_interface\r\n\r\n    def add_quicksettings(self):\r\n        with gr.Row(elem_id=\"quicksettings\", variant=\"compact\"):\r\n            for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])):\r\n                component = create_setting_component(k, is_quicksettings=True)\r\n                self.component_dict[k] = component\r\n\r\n    def add_functionality(self, demo):\r\n        self.submit.click(\r\n            fn=wrap_gradio_call_no_job(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]),\r\n            inputs=self.components,\r\n            outputs=[self.text_settings, self.result],\r\n        )\r\n\r\n        for _i, k, _item in self.quicksettings_list:\r\n            component = self.component_dict[k]\r\n            
info = opts.data_labels[k]\r\n\r\n            if isinstance(component, gr.Textbox):\r\n                methods = [component.submit, component.blur]\r\n            elif hasattr(component, 'release'):\r\n                methods = [component.release]\r\n            else:\r\n                methods = [component.change]\r\n\r\n            for method in methods:\r\n                method(\r\n                    fn=lambda value, k=k: self.run_settings_single(value, key=k),\r\n                    inputs=[component],\r\n                    outputs=[component, self.text_settings],\r\n                    show_progress=info.refresh is not None,\r\n                )\r\n\r\n        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)\r\n        button_set_checkpoint.click(\r\n            fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'),\r\n            _js=\"function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }\",\r\n            inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component],\r\n            outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings],\r\n        )\r\n\r\n        component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict]\r\n\r\n        def get_settings_values():\r\n            return [get_value_for_setting(key) for key in component_keys]\r\n\r\n        demo.load(\r\n            fn=get_settings_values,\r\n            inputs=[],\r\n            outputs=[self.component_dict[k] for k in component_keys],\r\n            queue=False,\r\n        )\r\n\r\n    def search(self, text):\r\n        print(text)\r\n\r\n        return [gr.update(visible=text in (comp.label or \"\")) for comp in self.components]\r\n"
  },
  {
    "path": "modules/ui_tempdir.py",
    "content": "import os\r\nimport tempfile\r\nfrom collections import namedtuple\r\nfrom pathlib import Path\r\n\r\nimport gradio.components\r\n\r\nfrom PIL import PngImagePlugin\r\n\r\nfrom modules import shared\r\n\r\n\r\nSavedfile = namedtuple(\"Savedfile\", [\"name\"])\r\n\r\n\r\ndef register_tmp_file(gradio, filename):\r\n    if hasattr(gradio, 'temp_file_sets'):  # gradio 3.15\r\n        gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)}\r\n\r\n    if hasattr(gradio, 'temp_dirs'):  # gradio 3.9\r\n        gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))}\r\n\r\n\r\ndef check_tmp_file(gradio, filename):\r\n    if hasattr(gradio, 'temp_file_sets'):\r\n        return any(filename in fileset for fileset in gradio.temp_file_sets)\r\n\r\n    if hasattr(gradio, 'temp_dirs'):\r\n        return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs)\r\n\r\n    return False\r\n\r\n\r\ndef save_pil_to_file(self, pil_image, dir=None, format=\"png\"):\r\n    already_saved_as = getattr(pil_image, 'already_saved_as', None)\r\n    if already_saved_as and os.path.isfile(already_saved_as):\r\n        register_tmp_file(shared.demo, already_saved_as)\r\n        filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'\r\n        register_tmp_file(shared.demo, filename_with_mtime)\r\n        return filename_with_mtime\r\n\r\n    if shared.opts.temp_dir != \"\":\r\n        dir = shared.opts.temp_dir\r\n    else:\r\n        os.makedirs(dir, exist_ok=True)\r\n\r\n    use_metadata = False\r\n    metadata = PngImagePlugin.PngInfo()\r\n    for key, value in pil_image.info.items():\r\n        if isinstance(key, str) and isinstance(value, str):\r\n            metadata.add_text(key, value)\r\n            use_metadata = True\r\n\r\n    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=\".png\", dir=dir)\r\n    pil_image.save(file_obj, 
pnginfo=(metadata if use_metadata else None))\r\n    return file_obj.name\r\n\r\n\r\ndef install_ui_tempdir_override():\r\n    \"\"\"override save to file function so that it also writes PNG info\"\"\"\r\n    gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file\r\n\r\n\r\ndef on_tmpdir_changed():\r\n    if shared.opts.temp_dir == \"\" or shared.demo is None:\r\n        return\r\n\r\n    os.makedirs(shared.opts.temp_dir, exist_ok=True)\r\n\r\n    register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, \"x\"))\r\n\r\n\r\ndef cleanup_tmpdr():\r\n    temp_dir = shared.opts.temp_dir\r\n    if temp_dir == \"\" or not os.path.isdir(temp_dir):\r\n        return\r\n\r\n    for root, _, files in os.walk(temp_dir, topdown=False):\r\n        for name in files:\r\n            _, extension = os.path.splitext(name)\r\n            if extension != \".png\":\r\n                continue\r\n\r\n            filename = os.path.join(root, name)\r\n            os.remove(filename)\r\n\r\n\r\ndef is_gradio_temp_path(path):\r\n    \"\"\"\r\n    Check if the path is a temp dir used by gradio\r\n    \"\"\"\r\n    path = Path(path)\r\n    if shared.opts.temp_dir and path.is_relative_to(shared.opts.temp_dir):\r\n        return True\r\n    if gradio_temp_dir := os.environ.get(\"GRADIO_TEMP_DIR\"):\r\n        if path.is_relative_to(gradio_temp_dir):\r\n            return True\r\n    if path.is_relative_to(Path(tempfile.gettempdir()) / \"gradio\"):\r\n        return True\r\n    return False\r\n"
  },
  {
    "path": "modules/ui_toprow.py",
    "content": "import gradio as gr\r\n\r\nfrom modules import shared, ui_prompt_styles\r\nimport modules.images\r\n\r\nfrom modules.ui_components import ToolButton\r\n\r\n\r\nclass Toprow:\r\n    \"\"\"Creates a top row UI with prompts, generate button, styles, extra little buttons for things, and enables some functionality related to their operation\"\"\"\r\n\r\n    prompt = None\r\n    prompt_img = None\r\n    negative_prompt = None\r\n\r\n    button_interrogate = None\r\n    button_deepbooru = None\r\n\r\n    interrupt = None\r\n    interrupting = None\r\n    skip = None\r\n    submit = None\r\n\r\n    paste = None\r\n    clear_prompt_button = None\r\n    apply_styles = None\r\n    restore_progress_button = None\r\n\r\n    token_counter = None\r\n    token_button = None\r\n    negative_token_counter = None\r\n    negative_token_button = None\r\n\r\n    ui_styles = None\r\n\r\n    submit_box = None\r\n\r\n    def __init__(self, is_img2img, is_compact=False, id_part=None):\r\n        if id_part is None:\r\n            id_part = \"img2img\" if is_img2img else \"txt2img\"\r\n\r\n        self.id_part = id_part\r\n        self.is_img2img = is_img2img\r\n        self.is_compact = is_compact\r\n\r\n        if not is_compact:\r\n            with gr.Row(elem_id=f\"{id_part}_toprow\", variant=\"compact\"):\r\n                self.create_classic_toprow()\r\n        else:\r\n            self.create_submit_box()\r\n\r\n    def create_classic_toprow(self):\r\n        self.create_prompts()\r\n\r\n        with gr.Column(scale=1, elem_id=f\"{self.id_part}_actions_column\"):\r\n            self.create_submit_box()\r\n\r\n            self.create_tools_row()\r\n\r\n            self.create_styles_ui()\r\n\r\n    def create_inline_toprow_prompts(self):\r\n        if not self.is_compact:\r\n            return\r\n\r\n        self.create_prompts()\r\n\r\n        with gr.Row(elem_classes=[\"toprow-compact-stylerow\"]):\r\n            with 
gr.Column(elem_classes=[\"toprow-compact-tools\"]):\r\n                self.create_tools_row()\r\n            with gr.Column():\r\n                self.create_styles_ui()\r\n\r\n    def create_inline_toprow_image(self):\r\n        if not self.is_compact:\r\n            return\r\n\r\n        self.submit_box.render()\r\n\r\n    def create_prompts(self):\r\n        with gr.Column(elem_id=f\"{self.id_part}_prompt_container\", elem_classes=[\"prompt-container-compact\"] if self.is_compact else [], scale=6):\r\n            with gr.Row(elem_id=f\"{self.id_part}_prompt_row\", elem_classes=[\"prompt-row\"]):\r\n                self.prompt = gr.Textbox(label=\"Prompt\", elem_id=f\"{self.id_part}_prompt\", show_label=False, lines=3, placeholder=\"Prompt\\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)\", elem_classes=[\"prompt\"])\r\n                self.prompt_img = gr.File(label=\"\", elem_id=f\"{self.id_part}_prompt_image\", file_count=\"single\", type=\"binary\", visible=False)\r\n\r\n            with gr.Row(elem_id=f\"{self.id_part}_neg_prompt_row\", elem_classes=[\"prompt-row\"]):\r\n                self.negative_prompt = gr.Textbox(label=\"Negative prompt\", elem_id=f\"{self.id_part}_neg_prompt\", show_label=False, lines=3, placeholder=\"Negative prompt\\n(Press Ctrl+Enter to generate, Alt+Enter to skip, Esc to interrupt)\", elem_classes=[\"prompt\"])\r\n\r\n        self.prompt_img.change(\r\n            fn=modules.images.image_data,\r\n            inputs=[self.prompt_img],\r\n            outputs=[self.prompt, self.prompt_img],\r\n            show_progress=False,\r\n        )\r\n\r\n    def create_submit_box(self):\r\n        with gr.Row(elem_id=f\"{self.id_part}_generate_box\", elem_classes=[\"generate-box\"] + ([\"generate-box-compact\"] if self.is_compact else []), render=not self.is_compact) as submit_box:\r\n            self.submit_box = submit_box\r\n\r\n            self.interrupt = gr.Button('Interrupt', 
elem_id=f\"{self.id_part}_interrupt\", elem_classes=\"generate-box-interrupt\", tooltip=\"End generation immediately or after completing current batch\")\r\n            self.skip = gr.Button('Skip', elem_id=f\"{self.id_part}_skip\", elem_classes=\"generate-box-skip\", tooltip=\"Stop generation of current batch and continues onto next batch\")\r\n            self.interrupting = gr.Button('Interrupting...', elem_id=f\"{self.id_part}_interrupting\", elem_classes=\"generate-box-interrupting\", tooltip=\"Interrupting generation...\")\r\n            self.submit = gr.Button('Generate', elem_id=f\"{self.id_part}_generate\", variant='primary', tooltip=\"Right click generate forever menu\")\r\n\r\n            def interrupt_function():\r\n                if not shared.state.stopping_generation and shared.state.job_count > 1 and shared.opts.interrupt_after_current:\r\n                    shared.state.stop_generating()\r\n                    gr.Info(\"Generation will stop after finishing this image, click again to stop immediately.\")\r\n                else:\r\n                    shared.state.interrupt()\r\n\r\n            self.skip.click(fn=shared.state.skip)\r\n            self.interrupt.click(fn=interrupt_function, _js='function(){ showSubmitInterruptingPlaceholder(\"' + self.id_part + '\"); }')\r\n            self.interrupting.click(fn=interrupt_function)\r\n\r\n    def create_tools_row(self):\r\n        with gr.Row(elem_id=f\"{self.id_part}_tools\"):\r\n            from modules.ui import paste_symbol, clear_prompt_symbol, restore_progress_symbol\r\n\r\n            self.paste = ToolButton(value=paste_symbol, elem_id=\"paste\", tooltip=\"Read generation parameters from prompt or last generation if prompt is empty into user interface.\")\r\n            self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f\"{self.id_part}_clear_prompt\", tooltip=\"Clear prompt\")\r\n            self.apply_styles = 
ToolButton(value=ui_prompt_styles.styles_materialize_symbol, elem_id=f\"{self.id_part}_style_apply\", tooltip=\"Apply all selected styles to prompts.\")\r\n\r\n            if self.is_img2img:\r\n                self.button_interrogate = ToolButton('📎', tooltip='Interrogate CLIP - use CLIP neural network to create a text describing the image, and put it into the prompt field', elem_id=\"interrogate\")\r\n                self.button_deepbooru = ToolButton('📦', tooltip='Interrogate DeepBooru - use DeepBooru neural network to create a text describing the image, and put it into the prompt field', elem_id=\"deepbooru\")\r\n\r\n            self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f\"{self.id_part}_restore_progress\", visible=False, tooltip=\"Restore progress\")\r\n\r\n            self.token_counter = gr.HTML(value=\"<span>0/75</span>\", elem_id=f\"{self.id_part}_token_counter\", elem_classes=[\"token-counter\"], visible=False)\r\n            self.token_button = gr.Button(visible=False, elem_id=f\"{self.id_part}_token_button\")\r\n            self.negative_token_counter = gr.HTML(value=\"<span>0/75</span>\", elem_id=f\"{self.id_part}_negative_token_counter\", elem_classes=[\"token-counter\"], visible=False)\r\n            self.negative_token_button = gr.Button(visible=False, elem_id=f\"{self.id_part}_negative_token_button\")\r\n\r\n            self.clear_prompt_button.click(\r\n                fn=lambda *x: x,\r\n                _js=\"confirm_clear_prompt\",\r\n                inputs=[self.prompt, self.negative_prompt],\r\n                outputs=[self.prompt, self.negative_prompt],\r\n            )\r\n\r\n    def create_styles_ui(self):\r\n        self.ui_styles = ui_prompt_styles.UiPromptStyles(self.id_part, self.prompt, self.negative_prompt)\r\n        self.ui_styles.setup_apply_button(self.apply_styles)\r\n"
  },
  {
    "path": "modules/upscaler.py",
    "content": "import os\nfrom abc import abstractmethod\n\nimport PIL\nfrom PIL import Image\n\nimport modules.shared\nfrom modules import modelloader, shared\n\nLANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)\nNEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)\n\n\nclass Upscaler:\n    name = None\n    model_path = None\n    model_name = None\n    model_url = None\n    enable = True\n    filter = None\n    model = None\n    user_path = None\n    scalers: list\n    tile = True\n\n    def __init__(self, create_dirs=False):\n        self.mod_pad_h = None\n        self.tile_size = modules.shared.opts.ESRGAN_tile\n        self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap\n        self.device = modules.shared.device\n        self.img = None\n        self.output = None\n        self.scale = 1\n        self.half = not modules.shared.cmd_opts.no_half\n        self.pre_pad = 0\n        self.mod_scale = None\n        self.model_download_path = None\n\n        if self.model_path is None and self.name:\n            self.model_path = os.path.join(shared.models_path, self.name)\n        if self.model_path and create_dirs:\n            os.makedirs(self.model_path, exist_ok=True)\n\n        try:\n            import cv2  # noqa: F401\n            self.can_tile = True\n        except Exception:\n            pass\n\n    @abstractmethod\n    def do_upscale(self, img: PIL.Image, selected_model: str):\n        return img\n\n    def upscale(self, img: PIL.Image, scale, selected_model: str = None):\n        self.scale = scale\n        dest_w = int((img.width * scale) // 8 * 8)\n        dest_h = int((img.height * scale) // 8 * 8)\n\n        for i in range(3):\n            if img.width >= dest_w and img.height >= dest_h and (i > 0 or scale != 1):\n                break\n\n            if shared.state.interrupted:\n                break\n\n            shape = (img.width, img.height)\n\n            img = 
self.do_upscale(img, selected_model)\n\n            if shape == (img.width, img.height):\n                break\n\n        if img.width != dest_w or img.height != dest_h:\n            img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)\n\n        return img\n\n    @abstractmethod\n    def load_model(self, path: str):\n        pass\n\n    def find_models(self, ext_filter=None) -> list:\n        return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter)\n\n    def update_status(self, prompt):\n        print(f\"\\nextras: {prompt}\", file=shared.progress_print_out)\n\n\nclass UpscalerData:\n    name = None\n    data_path = None\n    scale: int = 4\n    scaler: Upscaler = None\n    model: None\n\n    def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):\n        self.name = name\n        self.data_path = path\n        self.local_data_path = path\n        self.scaler = upscaler\n        self.scale = scale\n        self.model = model\n\n    def __repr__(self):\n        return f\"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>\"\n\n\nclass UpscalerNone(Upscaler):\n    name = \"None\"\n    scalers = []\n\n    def load_model(self, path):\n        pass\n\n    def do_upscale(self, img, selected_model=None):\n        return img\n\n    def __init__(self, dirname=None):\n        super().__init__(False)\n        self.scalers = [UpscalerData(\"None\", None, self)]\n\n\nclass UpscalerLanczos(Upscaler):\n    scalers = []\n\n    def do_upscale(self, img, selected_model=None):\n        return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)\n\n    def load_model(self, _):\n        pass\n\n    def __init__(self, dirname=None):\n        super().__init__(False)\n        self.name = \"Lanczos\"\n        self.scalers = [UpscalerData(\"Lanczos\", None, self)]\n\n\nclass 
UpscalerNearest(Upscaler):\n    scalers = []\n\n    def do_upscale(self, img, selected_model=None):\n        return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)\n\n    def load_model(self, _):\n        pass\n\n    def __init__(self, dirname=None):\n        super().__init__(False)\n        self.name = \"Nearest\"\n        self.scalers = [UpscalerData(\"Nearest\", None, self)]\n"
  },
  {
    "path": "modules/upscaler_utils.py",
    "content": "import logging\nfrom typing import Callable\n\nimport numpy as np\nimport torch\nimport tqdm\nfrom PIL import Image\n\nfrom modules import devices, images, shared, torch_utils\n\nlogger = logging.getLogger(__name__)\n\n\ndef pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor:\n    img = np.array(img.convert(\"RGB\"))\n    img = img[:, :, ::-1]  # flip RGB to BGR\n    img = np.transpose(img, (2, 0, 1))  # HWC to CHW\n    img = np.ascontiguousarray(img) / 255  # Rescale to [0, 1]\n    return torch.from_numpy(img)\n\n\ndef torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image:\n    if tensor.ndim == 4:\n        # If we're given a tensor with a batch dimension, squeeze it out\n        # (but only if it's a batch of size 1).\n        if tensor.shape[0] != 1:\n            raise ValueError(f\"{tensor.shape} does not describe a BCHW tensor\")\n        tensor = tensor.squeeze(0)\n    assert tensor.ndim == 3, f\"{tensor.shape} does not describe a CHW tensor\"\n    # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom?\n    arr = tensor.float().cpu().clamp_(0, 1).numpy()  # clamp\n    arr = 255.0 * np.moveaxis(arr, 0, 2)  # CHW to HWC, rescale\n    arr = arr.round().astype(np.uint8)\n    arr = arr[:, :, ::-1]  # flip BGR to RGB\n    return Image.fromarray(arr, \"RGB\")\n\n\ndef upscale_pil_patch(model, img: Image.Image) -> Image.Image:\n    \"\"\"\n    Upscale a given PIL image using the given model.\n    \"\"\"\n    param = torch_utils.get_param(model)\n\n    with torch.inference_mode():\n        tensor = pil_image_to_torch_bgr(img).unsqueeze(0)  # add batch dimension\n        tensor = tensor.to(device=param.device, dtype=param.dtype)\n        with devices.without_autocast():\n            return torch_bgr_to_pil_image(model(tensor))\n\n\ndef upscale_with_model(\n    model: Callable[[torch.Tensor], torch.Tensor],\n    img: Image.Image,\n    *,\n    tile_size: int,\n    tile_overlap: int = 0,\n    desc=\"tiled upscale\",\n) -> 
Image.Image:\n    if tile_size <= 0:\n        logger.debug(\"Upscaling %s without tiling\", img)\n        output = upscale_pil_patch(model, img)\n        logger.debug(\"=> %s\", output)\n        return output\n\n    grid = images.split_grid(img, tile_size, tile_size, tile_overlap)\n    newtiles = []\n\n    with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p:\n        for y, h, row in grid.tiles:\n            newrow = []\n            for x, w, tile in row:\n                if shared.state.interrupted:\n                    return img\n                output = upscale_pil_patch(model, tile)\n                scale_factor = output.width // tile.width\n                newrow.append([x * scale_factor, w * scale_factor, output])\n                p.update(1)\n            newtiles.append([y * scale_factor, h * scale_factor, newrow])\n\n    newgrid = images.Grid(\n        newtiles,\n        tile_w=grid.tile_w * scale_factor,\n        tile_h=grid.tile_h * scale_factor,\n        image_w=grid.image_w * scale_factor,\n        image_h=grid.image_h * scale_factor,\n        overlap=grid.overlap * scale_factor,\n    )\n    return images.combine_grid(newgrid)\n\n\ndef tiled_upscale_2(\n    img: torch.Tensor,\n    model,\n    *,\n    tile_size: int,\n    tile_overlap: int,\n    scale: int,\n    device: torch.device,\n    desc=\"Tiled upscale\",\n):\n    # Alternative implementation of `upscale_with_model` originally used by\n    # SwinIR and ScuNET.  
It differs from `upscale_with_model` in that tiling and\n    # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in\n    # Pillow space without weighting.\n\n    b, c, h, w = img.size()\n    tile_size = min(tile_size, h, w)\n\n    if tile_size <= 0:\n        logger.debug(\"Upscaling %s without tiling\", img.shape)\n        return model(img)\n\n    stride = tile_size - tile_overlap\n    h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]\n    w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]\n    result = torch.zeros(\n        b,\n        c,\n        h * scale,\n        w * scale,\n        device=device,\n        dtype=img.dtype,\n    )\n    weights = torch.zeros_like(result)\n    logger.debug(\"Upscaling %s to %s with tiles\", img.shape, result.shape)\n    with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar:\n        for h_idx in h_idx_list:\n            if shared.state.interrupted or shared.state.skipped:\n                break\n\n            for w_idx in w_idx_list:\n                if shared.state.interrupted or shared.state.skipped:\n                    break\n\n                # Only move this patch to the device if it's not already there.\n                in_patch = img[\n                    ...,\n                    h_idx : h_idx + tile_size,\n                    w_idx : w_idx + tile_size,\n                ].to(device=device)\n\n                out_patch = model(in_patch)\n\n                result[\n                    ...,\n                    h_idx * scale : (h_idx + tile_size) * scale,\n                    w_idx * scale : (w_idx + tile_size) * scale,\n                ].add_(out_patch)\n\n                out_patch_mask = torch.ones_like(out_patch)\n\n                weights[\n                    ...,\n                    h_idx * scale : (h_idx + tile_size) * scale,\n                    w_idx * scale : (w_idx + tile_size) * 
scale,\n                ].add_(out_patch_mask)\n\n                pbar.update(1)\n\n    output = result.div_(weights)\n\n    return output\n\n\ndef upscale_2(\n    img: Image.Image,\n    model,\n    *,\n    tile_size: int,\n    tile_overlap: int,\n    scale: int,\n    desc: str,\n):\n    \"\"\"\n    Convenience wrapper around `tiled_upscale_2` that handles PIL images.\n    \"\"\"\n    param = torch_utils.get_param(model)\n    tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0)  # add batch dimension\n\n    with torch.no_grad():\n        output = tiled_upscale_2(\n            tensor,\n            model,\n            tile_size=tile_size,\n            tile_overlap=tile_overlap,\n            scale=scale,\n            desc=desc,\n            device=param.device,\n        )\n    return torch_bgr_to_pil_image(output)\n"
  },
  {
    "path": "modules/util.py",
    "content": "import os\r\nimport re\r\n\r\nfrom modules import shared\r\nfrom modules.paths_internal import script_path, cwd\r\n\r\n\r\ndef natural_sort_key(s, regex=re.compile('([0-9]+)')):\r\n    return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]\r\n\r\n\r\ndef listfiles(dirname):\r\n    filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(\".\")]\r\n    return [file for file in filenames if os.path.isfile(file)]\r\n\r\n\r\ndef html_path(filename):\r\n    return os.path.join(script_path, \"html\", filename)\r\n\r\n\r\ndef html(filename):\r\n    path = html_path(filename)\r\n\r\n    try:\r\n        with open(path, encoding=\"utf8\") as file:\r\n            return file.read()\r\n    except OSError:\r\n        return \"\"\r\n\r\n\r\ndef walk_files(path, allowed_extensions=None):\r\n    if not os.path.exists(path):\r\n        return\r\n\r\n    if allowed_extensions is not None:\r\n        allowed_extensions = set(allowed_extensions)\r\n\r\n    items = list(os.walk(path, followlinks=True))\r\n    items = sorted(items, key=lambda x: natural_sort_key(x[0]))\r\n\r\n    for root, _, files in items:\r\n        for filename in sorted(files, key=natural_sort_key):\r\n            if allowed_extensions is not None:\r\n                _, ext = os.path.splitext(filename)\r\n                if ext.lower() not in allowed_extensions:\r\n                    continue\r\n\r\n            if not shared.opts.list_hidden_files and (\"/.\" in root or \"\\\\.\" in root):\r\n                continue\r\n\r\n            yield os.path.join(root, filename)\r\n\r\n\r\ndef ldm_print(*args, **kwargs):\r\n    if shared.opts.hide_ldm_prints:\r\n        return\r\n\r\n    print(*args, **kwargs)\r\n\r\n\r\ndef truncate_path(target_path, base_path=cwd):\r\n    abs_target, abs_base = os.path.abspath(target_path), os.path.abspath(base_path)\r\n    try:\r\n        if os.path.commonpath([abs_target, abs_base]) 
== abs_base:\r\n            return os.path.relpath(abs_target, abs_base)\r\n    except ValueError:\r\n        pass\r\n    return abs_target\r\n\r\n\r\nclass MassFileListerCachedDir:\r\n    \"\"\"A class that caches file metadata for a specific directory.\"\"\"\r\n\r\n    def __init__(self, dirname):\r\n        self.files = None\r\n        self.files_cased = None\r\n        self.dirname = dirname\r\n\r\n        stats = ((x.name, x.stat(follow_symlinks=False)) for x in os.scandir(self.dirname))\r\n        files = [(n, s.st_mtime, s.st_ctime) for n, s in stats]\r\n        self.files = {x[0].lower(): x for x in files}\r\n        self.files_cased = {x[0]: x for x in files}\r\n\r\n    def update_entry(self, filename):\r\n        \"\"\"Add a file to the cache\"\"\"\r\n        file_path = os.path.join(self.dirname, filename)\r\n        try:\r\n            stat = os.stat(file_path)\r\n            entry = (filename, stat.st_mtime, stat.st_ctime)\r\n            self.files[filename.lower()] = entry\r\n            self.files_cased[filename] = entry\r\n        except FileNotFoundError as e:\r\n            print(f'MassFileListerCachedDir.add_entry: \"{file_path}\" {e}')\r\n\r\n\r\nclass MassFileLister:\r\n    \"\"\"A class that provides a way to check for the existence and mtime/ctile of files without doing more than one stat call per file.\"\"\"\r\n\r\n    def __init__(self):\r\n        self.cached_dirs = {}\r\n\r\n    def find(self, path):\r\n        \"\"\"\r\n        Find the metadata for a file at the given path.\r\n\r\n        Returns:\r\n            tuple or None: A tuple of (name, mtime, ctime) if the file exists, or None if it does not.\r\n        \"\"\"\r\n\r\n        dirname, filename = os.path.split(path)\r\n\r\n        cached_dir = self.cached_dirs.get(dirname)\r\n        if cached_dir is None:\r\n            cached_dir = MassFileListerCachedDir(dirname)\r\n            self.cached_dirs[dirname] = cached_dir\r\n\r\n        stats = 
cached_dir.files_cased.get(filename)\r\n        if stats is not None:\r\n            return stats\r\n\r\n        stats = cached_dir.files.get(filename.lower())\r\n        if stats is None:\r\n            return None\r\n\r\n        try:\r\n            os_stats = os.stat(path, follow_symlinks=False)\r\n            return filename, os_stats.st_mtime, os_stats.st_ctime\r\n        except Exception:\r\n            return None\r\n\r\n    def exists(self, path):\r\n        \"\"\"Check if a file exists at the given path.\"\"\"\r\n\r\n        return self.find(path) is not None\r\n\r\n    def mctime(self, path):\r\n        \"\"\"\r\n        Get the modification and creation times for a file at the given path.\r\n\r\n        Returns:\r\n            tuple: A tuple of (mtime, ctime) if the file exists, or (0, 0) if it does not.\r\n        \"\"\"\r\n\r\n        stats = self.find(path)\r\n        return (0, 0) if stats is None else stats[1:3]\r\n\r\n    def reset(self):\r\n        \"\"\"Clear the cache of all directories.\"\"\"\r\n        self.cached_dirs.clear()\r\n\r\n    def update_file_entry(self, path):\r\n        \"\"\"Update the cache for a specific directory.\"\"\"\r\n        dirname, filename = os.path.split(path)\r\n        if cached_dir := self.cached_dirs.get(dirname):\r\n            cached_dir.update_entry(filename)\r\n\r\ndef topological_sort(dependencies):\r\n    \"\"\"Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.\r\n    Ignores errors relating to missing dependencies or circular dependencies\r\n    \"\"\"\r\n\r\n    visited = {}\r\n    result = []\r\n\r\n    def inner(name):\r\n        visited[name] = True\r\n\r\n        for dep in dependencies.get(name, []):\r\n            if dep in dependencies and dep not in visited:\r\n                inner(dep)\r\n\r\n        result.append(name)\r\n\r\n    for depname in dependencies:\r\n        if depname not in visited:\r\n            inner(depname)\r\n\r\n  
  return result\r\n\r\n\r\ndef open_folder(path):\r\n    \"\"\"Open a folder in the file manager of the respect OS.\"\"\"\r\n    # import at function level to avoid potential issues\r\n    import gradio as gr\r\n    import platform\r\n    import sys\r\n    import subprocess\r\n\r\n    if not os.path.exists(path):\r\n        msg = f'Folder \"{path}\" does not exist. after you save an image, the folder will be created.'\r\n        print(msg)\r\n        gr.Info(msg)\r\n        return\r\n    elif not os.path.isdir(path):\r\n        msg = f\"\"\"\r\nWARNING\r\nAn open_folder request was made with an path that is not a folder.\r\nThis could be an error or a malicious attempt to run code on your computer.\r\nRequested path was: {path}\r\n\"\"\"\r\n        print(msg, file=sys.stderr)\r\n        gr.Warning(msg)\r\n        return\r\n\r\n    path = os.path.normpath(path)\r\n    if platform.system() == \"Windows\":\r\n        os.startfile(path)\r\n    elif platform.system() == \"Darwin\":\r\n        subprocess.Popen([\"open\", path])\r\n    elif \"microsoft-standard-WSL2\" in platform.uname().release:\r\n        subprocess.Popen([\"explorer.exe\", subprocess.check_output([\"wslpath\", \"-w\", path])])\r\n    else:\r\n        subprocess.Popen([\"xdg-open\", path])\r\n"
  },
  {
    "path": "modules/xlmr.py",
    "content": "from transformers import BertPreTrainedModel, BertConfig\nimport torch.nn as nn\nimport torch\nfrom transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig\nfrom transformers import XLMRobertaModel,XLMRobertaTokenizer\nfrom typing import Optional\n\nfrom modules import torch_utils\n\n\nclass BertSeriesConfig(BertConfig):\n    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act=\"gelu\", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type=\"absolute\", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn=\"average\",learn_encoder=False,model_type='bert',**kwargs):\n\n        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)\n        self.project_dim = project_dim\n        self.pooler_fn = pooler_fn\n        self.learn_encoder = learn_encoder\n\nclass RobertaSeriesConfig(XLMRobertaConfig):\n    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):\n        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)\n        self.project_dim = project_dim\n        self.pooler_fn = pooler_fn\n        self.learn_encoder = learn_encoder\n\n\nclass BertSeriesModelWithTransformation(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n    config_class = BertSeriesConfig\n\n    def __init__(self, 
config=None, **kargs):\n        # modify initialization for autoloading\n        if config is None:\n            config = XLMRobertaConfig()\n            config.attention_probs_dropout_prob= 0.1\n            config.bos_token_id=0\n            config.eos_token_id=2\n            config.hidden_act='gelu'\n            config.hidden_dropout_prob=0.1\n            config.hidden_size=1024\n            config.initializer_range=0.02\n            config.intermediate_size=4096\n            config.layer_norm_eps=1e-05\n            config.max_position_embeddings=514\n\n            config.num_attention_heads=16\n            config.num_hidden_layers=24\n            config.output_past=True\n            config.pad_token_id=1\n            config.position_embedding_type= \"absolute\"\n\n            config.type_vocab_size= 1\n            config.use_cache=True\n            config.vocab_size= 250002\n            config.project_dim = 768\n            config.learn_encoder = False\n        super().__init__(config)\n        self.roberta = XLMRobertaModel(config)\n        self.transformation = nn.Linear(config.hidden_size,config.project_dim)\n        self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')\n        self.pooler = lambda x: x[:,0]\n        self.post_init()\n\n    def encode(self,c):\n        device = torch_utils.get_param(self).device\n        text = self.tokenizer(c,\n                        truncation=True,\n                        max_length=77,\n                        return_length=False,\n                        return_overflowing_tokens=False,\n                        padding=\"max_length\",\n                        return_tensors=\"pt\")\n        text[\"input_ids\"] = torch.tensor(text[\"input_ids\"]).to(device)\n        text[\"attention_mask\"] = torch.tensor(\n            text['attention_mask']).to(device)\n        features = self(**text)\n        return 
features['projection_state']\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        token_type_ids: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) :\n        r\"\"\"\n        \"\"\"\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n\n        outputs = self.roberta(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=True,\n            return_dict=return_dict,\n        )\n\n        # last module outputs\n        sequence_output = outputs[0]\n\n\n        # project every module\n        sequence_output_ln = self.pre_LN(sequence_output)\n\n        # pooler\n        pooler_output = self.pooler(sequence_output_ln)\n        pooler_output = self.transformation(pooler_output)\n        projection_state = self.transformation(outputs.last_hidden_state)\n\n        return {\n            'pooler_output':pooler_output,\n            'last_hidden_state':outputs.last_hidden_state,\n            'hidden_states':outputs.hidden_states,\n            'attentions':outputs.attentions,\n            'projection_state':projection_state,\n            
'sequence_out': sequence_output\n        }\n\n\nclass RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):\n    base_model_prefix = 'roberta'\n    config_class= RobertaSeriesConfig\n"
  },
  {
    "path": "modules/xlmr_m18.py",
    "content": "from transformers import BertPreTrainedModel,BertConfig\nimport torch.nn as nn\nimport torch\nfrom transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig\nfrom transformers import XLMRobertaModel,XLMRobertaTokenizer\nfrom typing import Optional\nfrom modules import torch_utils\n\n\nclass BertSeriesConfig(BertConfig):\n    def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act=\"gelu\", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type=\"absolute\", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn=\"average\",learn_encoder=False,model_type='bert',**kwargs):\n\n        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs)\n        self.project_dim = project_dim\n        self.pooler_fn = pooler_fn\n        self.learn_encoder = learn_encoder\n\nclass RobertaSeriesConfig(XLMRobertaConfig):\n    def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs):\n        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)\n        self.project_dim = project_dim\n        self.pooler_fn = pooler_fn\n        self.learn_encoder = learn_encoder\n\n\nclass BertSeriesModelWithTransformation(BertPreTrainedModel):\n\n    _keys_to_ignore_on_load_unexpected = [r\"pooler\"]\n    _keys_to_ignore_on_load_missing = [r\"position_ids\", r\"predictions.decoder.bias\"]\n    config_class = BertSeriesConfig\n\n    def __init__(self, 
config=None, **kargs):\n        # modify initialization for autoloading\n        if config is None:\n            config = XLMRobertaConfig()\n            config.attention_probs_dropout_prob= 0.1\n            config.bos_token_id=0\n            config.eos_token_id=2\n            config.hidden_act='gelu'\n            config.hidden_dropout_prob=0.1\n            config.hidden_size=1024\n            config.initializer_range=0.02\n            config.intermediate_size=4096\n            config.layer_norm_eps=1e-05\n            config.max_position_embeddings=514\n\n            config.num_attention_heads=16\n            config.num_hidden_layers=24\n            config.output_past=True\n            config.pad_token_id=1\n            config.position_embedding_type= \"absolute\"\n\n            config.type_vocab_size= 1\n            config.use_cache=True\n            config.vocab_size= 250002\n            config.project_dim = 1024\n            config.learn_encoder = False\n        super().__init__(config)\n        self.roberta = XLMRobertaModel(config)\n        self.transformation = nn.Linear(config.hidden_size,config.project_dim)\n        # self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')\n        # self.pooler = lambda x: x[:,0]\n        # self.post_init()\n\n        self.has_pre_transformation = True\n        if self.has_pre_transformation:\n            self.transformation_pre = nn.Linear(config.hidden_size, config.project_dim)\n            self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.post_init()\n\n    def encode(self,c):\n        device = torch_utils.get_param(self).device\n        text = self.tokenizer(c,\n                        truncation=True,\n                        max_length=77,\n                        return_length=False,\n                        return_overflowing_tokens=False,\n                        
padding=\"max_length\",\n                        return_tensors=\"pt\")\n        text[\"input_ids\"] = torch.tensor(text[\"input_ids\"]).to(device)\n        text[\"attention_mask\"] = torch.tensor(\n            text['attention_mask']).to(device)\n        features = self(**text)\n        return features['projection_state']\n\n    def forward(\n        self,\n        input_ids: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        token_type_ids: Optional[torch.Tensor] = None,\n        position_ids: Optional[torch.Tensor] = None,\n        head_mask: Optional[torch.Tensor] = None,\n        inputs_embeds: Optional[torch.Tensor] = None,\n        encoder_hidden_states: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        output_attentions: Optional[bool] = None,\n        return_dict: Optional[bool] = None,\n        output_hidden_states: Optional[bool] = None,\n    ) :\n        r\"\"\"\n        \"\"\"\n\n        return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n\n\n        outputs = self.roberta(\n            input_ids=input_ids,\n            attention_mask=attention_mask,\n            token_type_ids=token_type_ids,\n            position_ids=position_ids,\n            head_mask=head_mask,\n            inputs_embeds=inputs_embeds,\n            encoder_hidden_states=encoder_hidden_states,\n            encoder_attention_mask=encoder_attention_mask,\n            output_attentions=output_attentions,\n            output_hidden_states=True,\n            return_dict=return_dict,\n        )\n\n        # # last module outputs\n        # sequence_output = outputs[0]\n\n\n        # # project every module\n        # sequence_output_ln = self.pre_LN(sequence_output)\n\n        # # pooler\n        # pooler_output = self.pooler(sequence_output_ln)\n        # pooler_output = self.transformation(pooler_output)\n        # projection_state = 
self.transformation(outputs.last_hidden_state)\n\n        if self.has_pre_transformation:\n            sequence_output2 = outputs[\"hidden_states\"][-2]\n            sequence_output2 = self.pre_LN(sequence_output2)\n            projection_state2 = self.transformation_pre(sequence_output2)\n\n            return {\n                \"projection_state\": projection_state2,\n                \"last_hidden_state\": outputs.last_hidden_state,\n                \"hidden_states\": outputs.hidden_states,\n                \"attentions\": outputs.attentions,\n            }\n        else:\n            projection_state = self.transformation(outputs.last_hidden_state)\n            return {\n                \"projection_state\": projection_state,\n                \"last_hidden_state\": outputs.last_hidden_state,\n                \"hidden_states\": outputs.hidden_states,\n                \"attentions\": outputs.attentions,\n            }\n\n\n        # return {\n        #     'pooler_output':pooler_output,\n        #     'last_hidden_state':outputs.last_hidden_state,\n        #     'hidden_states':outputs.hidden_states,\n        #     'attentions':outputs.attentions,\n        #     'projection_state':projection_state,\n        #     'sequence_out': sequence_output\n        # }\n\n\nclass RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):\n    base_model_prefix = 'roberta'\n    config_class= RobertaSeriesConfig\n"
  },
  {
    "path": "modules/xpu_specific.py",
    "content": "from modules import shared\nfrom modules.sd_hijack_utils import CondFunc\n\nhas_ipex = False\ntry:\n    import torch\n    import intel_extension_for_pytorch as ipex # noqa: F401\n    has_ipex = True\nexcept Exception:\n    pass\n\n\ndef check_for_xpu():\n    return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()\n\n\ndef get_xpu_device_string():\n    if shared.cmd_opts.device_id is not None:\n        return f\"xpu:{shared.cmd_opts.device_id}\"\n    return \"xpu\"\n\n\ndef torch_xpu_gc():\n    with torch.xpu.device(get_xpu_device_string()):\n        torch.xpu.empty_cache()\n\n\nhas_xpu = check_for_xpu()\n\n\n# Arc GPU cannot allocate a single block larger than 4GB: https://github.com/intel/compute-runtime/issues/627\n# Here we implement a slicing algorithm to split large batch size into smaller chunks,\n# so that SDPA of each chunk wouldn't require any allocation larger than ARC_SINGLE_ALLOCATION_LIMIT.\n# The heuristic limit (TOTAL_VRAM // 8) is tuned for Intel Arc A770 16G and Arc A750 8G,\n# which is the best trade-off between VRAM usage and performance.\nARC_SINGLE_ALLOCATION_LIMIT = {}\norig_sdp_attn_func = torch.nn.functional.scaled_dot_product_attention\ndef torch_xpu_scaled_dot_product_attention(\n    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, *args, **kwargs\n):\n    # cast to same dtype first\n    key = key.to(query.dtype)\n    value = value.to(query.dtype)\n    if attn_mask is not None and attn_mask.dtype != torch.bool:\n        attn_mask = attn_mask.to(query.dtype)\n\n    N = query.shape[:-2]  # Batch size\n    L = query.size(-2)  # Target sequence length\n    E = query.size(-1)  # Embedding dimension of the query and key\n    S = key.size(-2)  # Source sequence length\n    Ev = value.size(-1)  # Embedding dimension of the value\n\n    total_batch_size = torch.numel(torch.empty(N))\n    device_id = query.device.index\n    if device_id not in ARC_SINGLE_ALLOCATION_LIMIT:\n        
ARC_SINGLE_ALLOCATION_LIMIT[device_id] = min(torch.xpu.get_device_properties(device_id).total_memory // 8, 4 * 1024 * 1024 * 1024)\n    batch_size_limit = max(1, ARC_SINGLE_ALLOCATION_LIMIT[device_id] // (L * S * query.element_size()))\n\n    if total_batch_size <= batch_size_limit:\n        return orig_sdp_attn_func(\n            query,\n            key,\n            value,\n            attn_mask,\n            dropout_p,\n            is_causal,\n            *args, **kwargs\n        )\n\n    query = torch.reshape(query, (-1, L, E))\n    key = torch.reshape(key, (-1, S, E))\n    value = torch.reshape(value, (-1, S, Ev))\n    if attn_mask is not None:\n        attn_mask = attn_mask.view(-1, L, S)\n    chunk_count = (total_batch_size + batch_size_limit - 1) // batch_size_limit\n    outputs = []\n    for i in range(chunk_count):\n        attn_mask_chunk = (\n            None\n            if attn_mask is None\n            else attn_mask[i * batch_size_limit : (i + 1) * batch_size_limit, :, :]\n        )\n        chunk_output = orig_sdp_attn_func(\n            query[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],\n            key[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],\n            value[i * batch_size_limit : (i + 1) * batch_size_limit, :, :],\n            attn_mask_chunk,\n            dropout_p,\n            is_causal,\n            *args, **kwargs\n        )\n        outputs.append(chunk_output)\n    result = torch.cat(outputs, dim=0)\n    return torch.reshape(result, (*N, L, Ev))\n\n\ndef is_xpu_device(device: str | torch.device = None):\n    if device is None:\n        return False\n    if isinstance(device, str):\n        return device.startswith(\"xpu\")\n    return device.type == \"xpu\"\n\n\nif has_xpu:\n    try:\n        # torch.Generator supports \"xpu\" device since 2.1\n        torch.Generator(\"xpu\")\n    except RuntimeError:\n        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API 
doesn't support XPU device (for torch < 2.1)\n        CondFunc('torch.Generator',\n            lambda orig_func, device=None: torch.xpu.Generator(device),\n            lambda orig_func, device=None: is_xpu_device(device))\n\n    # W/A for some OPs that could not handle different input dtypes\n    CondFunc('torch.nn.functional.layer_norm',\n        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:\n        orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),\n        lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:\n        weight is not None and input.dtype != weight.data.dtype)\n    CondFunc('torch.nn.modules.GroupNorm.forward',\n        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),\n        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)\n    CondFunc('torch.nn.modules.linear.Linear.forward',\n        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),\n        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)\n    CondFunc('torch.nn.modules.conv.Conv2d.forward',\n        lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),\n        lambda orig_func, self, input: input.dtype != self.weight.data.dtype)\n    CondFunc('torch.bmm',\n        lambda orig_func, input, mat2, out=None: orig_func(input.to(mat2.dtype), mat2, out=out),\n        lambda orig_func, input, mat2, out=None: input.dtype != mat2.dtype)\n    CondFunc('torch.cat',\n        lambda orig_func, tensors, dim=0, out=None: orig_func([t.to(tensors[0].dtype) for t in tensors], dim=dim, out=out),\n        lambda orig_func, tensors, dim=0, out=None: not all(t.dtype == tensors[0].dtype for t in tensors))\n    CondFunc('torch.nn.functional.scaled_dot_product_attention',\n        lambda orig_func, *args, **kwargs: torch_xpu_scaled_dot_product_attention(*args, **kwargs),\n        lambda 
orig_func, query, *args, **kwargs: query.is_xpu)\n"
  },
  {
    "path": "package.json",
    "content": "{\n  \"name\": \"stable-diffusion-webui\",\n  \"version\": \"0.0.0\",\n  \"devDependencies\": {\n    \"eslint\": \"^8.40.0\"\n  },\n  \"scripts\": {\n    \"lint\": \"eslint .\",\n    \"fix\": \"eslint --fix .\"\n  }\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.ruff]\n\ntarget-version = \"py39\"\n\n[tool.ruff.lint]\n\nextend-select = [\n  \"B\",\n  \"C\",\n  \"I\",\n  \"W\",\n]\n\nexclude = [\n\t\"extensions\",\n\t\"extensions-disabled\",\n]\n\nignore = [\n\t\"E501\", # Line too long\n\t\"E721\", # Do not compare types, use `isinstance`\n\t\"E731\", # Do not assign a `lambda` expression, use a `def`\n\t\n\t\"I001\", # Import block is un-sorted or un-formatted\n\t\"C901\", # Function is too complex\n\t\"C408\", # Rewrite as a literal\n\t\"W605\", # invalid escape sequence, messes with some docstrings\n]\n\n[tool.ruff.lint.per-file-ignores]\n\"webui.py\" = [\"E402\"]  # Module level import not at top of file\n\n[tool.ruff.lint.flake8-bugbear]\n# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.\nextend-immutable-calls = [\"fastapi.Depends\", \"fastapi.security.HTTPBasic\"]\n\n[tool.pytest.ini_options]\nbase_url = \"http://127.0.0.1:7860\"\n"
  },
  {
    "path": "requirements-test.txt",
    "content": "pytest-base-url~=2.0\npytest-cov~=4.0\npytest~=7.3\n"
  },
  {
    "path": "requirements.txt",
    "content": "GitPython\r\nPillow\r\naccelerate\r\n\r\nblendmodes\r\nclean-fid\r\ndiskcache\r\neinops\r\nfacexlib\r\nfastapi>=0.90.1\r\ngradio==3.41.2\r\ninflection\r\njsonmerge\r\nkornia\r\nlark\r\nnumpy\r\nomegaconf\r\nopen-clip-torch\r\n\r\npiexif\r\nprotobuf==3.20.0\r\npsutil\r\npytorch_lightning\r\nrequests\r\nresize-right\r\n\r\nsafetensors\r\nscikit-image>=0.19\r\ntomesd\r\ntorch\r\ntorchdiffeq\r\ntorchsde\r\ntransformers==4.30.2\r\npillow-avif-plugin==1.4.3"
  },
  {
    "path": "requirements_npu.txt",
    "content": "cloudpickle\ndecorator\nsynr==0.5.0\ntornado\n"
  },
  {
    "path": "requirements_versions.txt",
    "content": "setuptools==69.5.1  # temp fix for compatibility with some old packages\r\nGitPython==3.1.32\r\nPillow==9.5.0\r\naccelerate==0.21.0\r\nblendmodes==2022\r\nclean-fid==0.1.35\r\ndiskcache==5.6.3\r\neinops==0.4.1\r\nfacexlib==0.3.0\r\nfastapi==0.94.0\r\ngradio==3.41.2\r\nhttpcore==0.15\r\ninflection==0.5.1\r\njsonmerge==1.8.0\r\nkornia==0.6.7\r\nlark==1.1.2\r\nnumpy==1.26.2\r\nomegaconf==2.2.3\r\nopen-clip-torch==2.20.0\r\npiexif==1.1.3\r\nprotobuf==3.20.0\r\npsutil==5.9.5\r\npytorch_lightning==1.9.4\r\nresize-right==0.0.2\r\nsafetensors==0.4.2\r\nscikit-image==0.21.0\r\nspandrel==0.3.4\r\nspandrel-extra-arches==0.1.1\r\ntomesd==0.1.3\r\ntorch\r\ntorchdiffeq==0.2.3\r\ntorchsde==0.2.6\r\ntransformers==4.30.2\r\nhttpx==0.24.1\r\npillow-avif-plugin==1.4.3\r\n"
  },
  {
    "path": "script.js",
    "content": "function gradioApp() {\n    const elems = document.getElementsByTagName('gradio-app');\n    const elem = elems.length == 0 ? document : elems[0];\n\n    if (elem !== document) {\n        elem.getElementById = function(id) {\n            return document.getElementById(id);\n        };\n    }\n    return elem.shadowRoot ? elem.shadowRoot : elem;\n}\n\n/**\n * Get the currently selected top-level UI tab button (e.g. the button that says \"Extras\").\n */\nfunction get_uiCurrentTab() {\n    return gradioApp().querySelector('#tabs > .tab-nav > button.selected');\n}\n\n/**\n * Get the first currently visible top-level UI tab content (e.g. the div hosting the \"txt2img\" UI).\n */\nfunction get_uiCurrentTabContent() {\n    return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*=\"display: none\"])');\n}\n\nvar uiUpdateCallbacks = [];\nvar uiAfterUpdateCallbacks = [];\nvar uiLoadedCallbacks = [];\nvar uiTabChangeCallbacks = [];\nvar optionsChangedCallbacks = [];\nvar optionsAvailableCallbacks = [];\nvar uiAfterUpdateTimeout = null;\nvar uiCurrentTab = null;\n\n/**\n * Register callback to be called at each UI update.\n * The callback receives an array of MutationRecords as an argument.\n */\nfunction onUiUpdate(callback) {\n    uiUpdateCallbacks.push(callback);\n}\n\n/**\n * Register callback to be called soon after UI updates.\n * The callback receives no arguments.\n *\n * This is preferred over `onUiUpdate` if you don't need\n * access to the MutationRecords, as your function will\n * not be called quite as often.\n */\nfunction onAfterUiUpdate(callback) {\n    uiAfterUpdateCallbacks.push(callback);\n}\n\n/**\n * Register callback to be called when the UI is loaded.\n * The callback receives no arguments.\n */\nfunction onUiLoaded(callback) {\n    uiLoadedCallbacks.push(callback);\n}\n\n/**\n * Register callback to be called when the UI tab is changed.\n * The callback receives no arguments.\n */\nfunction onUiTabChange(callback) {\n    
uiTabChangeCallbacks.push(callback);\n}\n\n/**\n * Register callback to be called when the options are changed.\n * The callback receives no arguments.\n * @param callback\n */\nfunction onOptionsChanged(callback) {\n    optionsChangedCallbacks.push(callback);\n}\n\n/**\n * Register callback to be called when the options (in opts global variable) are available.\n * The callback receives no arguments.\n * If you register the callback after the options are available, it's just immediately called.\n */\nfunction onOptionsAvailable(callback) {\n    if (Object.keys(opts).length != 0) {\n        callback();\n        return;\n    }\n\n    optionsAvailableCallbacks.push(callback);\n}\n\nfunction executeCallbacks(queue, arg) {\n    for (const callback of queue) {\n        try {\n            callback(arg);\n        } catch (e) {\n            console.error(\"error running callback\", callback, \":\", e);\n        }\n    }\n}\n\n/**\n * Schedule the execution of the callbacks registered with onAfterUiUpdate.\n * The callbacks are executed after a short while, unless another call to this function\n * is made before that time. 
IOW, the callbacks are executed only once, even\n * when there are multiple mutations observed.\n */\nfunction scheduleAfterUiUpdateCallbacks() {\n    clearTimeout(uiAfterUpdateTimeout);\n    uiAfterUpdateTimeout = setTimeout(function() {\n        executeCallbacks(uiAfterUpdateCallbacks);\n    }, 200);\n}\n\nvar executedOnLoaded = false;\n\ndocument.addEventListener(\"DOMContentLoaded\", function() {\n    var mutationObserver = new MutationObserver(function(m) {\n        if (!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')) {\n            executedOnLoaded = true;\n            executeCallbacks(uiLoadedCallbacks);\n        }\n\n        executeCallbacks(uiUpdateCallbacks, m);\n        scheduleAfterUiUpdateCallbacks();\n        const newTab = get_uiCurrentTab();\n        if (newTab && (newTab !== uiCurrentTab)) {\n            uiCurrentTab = newTab;\n            executeCallbacks(uiTabChangeCallbacks);\n        }\n    });\n    mutationObserver.observe(gradioApp(), {childList: true, subtree: true});\n});\n\n/**\n * Add keyboard shortcuts:\n * Ctrl+Enter to start/restart a generation\n * Alt/Option+Enter to skip a generation\n * Esc to interrupt a generation\n */\ndocument.addEventListener('keydown', function(e) {\n    const isEnter = e.key === 'Enter' || e.keyCode === 13;\n    const isCtrlKey = e.metaKey || e.ctrlKey;\n    const isAltKey = e.altKey;\n    const isEsc = e.key === 'Escape';\n\n    const generateButton = get_uiCurrentTabContent().querySelector('button[id$=_generate]');\n    const interruptButton = get_uiCurrentTabContent().querySelector('button[id$=_interrupt]');\n    const skipButton = get_uiCurrentTabContent().querySelector('button[id$=_skip]');\n\n    if (isCtrlKey && isEnter) {\n        if (interruptButton.style.display === 'block') {\n            interruptButton.click();\n            const callback = (mutationList) => {\n                for (const mutation of mutationList) {\n                    if (mutation.type === 'attributes' && 
mutation.attributeName === 'style') {\n                        if (interruptButton.style.display === 'none') {\n                            generateButton.click();\n                            observer.disconnect();\n                        }\n                    }\n                }\n            };\n            const observer = new MutationObserver(callback);\n            observer.observe(interruptButton, {attributes: true});\n        } else {\n            generateButton.click();\n        }\n        e.preventDefault();\n    }\n\n    if (isAltKey && isEnter) {\n        skipButton.click();\n        e.preventDefault();\n    }\n\n    if (isEsc) {\n        const globalPopup = document.querySelector('.global-popup');\n        const lightboxModal = document.querySelector('#lightboxModal');\n        if (!globalPopup || globalPopup.style.display === 'none') {\n            if (document.activeElement === lightboxModal) return;\n            if (interruptButton.style.display === 'block') {\n                interruptButton.click();\n                e.preventDefault();\n            }\n        }\n    }\n});\n\n/**\n * checks that a UI element is not in another hidden element or tab content\n */\nfunction uiElementIsVisible(el) {\n    if (el === document) {\n        return true;\n    }\n\n    const computedStyle = getComputedStyle(el);\n    const isVisible = computedStyle.display !== 'none';\n\n    if (!isVisible) return false;\n    return uiElementIsVisible(el.parentNode);\n}\n\nfunction uiElementInSight(el) {\n    const clRect = el.getBoundingClientRect();\n    const windowHeight = window.innerHeight;\n    const isOnScreen = clRect.bottom > 0 && clRect.top < windowHeight;\n\n    return isOnScreen;\n}\n"
  },
  {
    "path": "scripts/custom_code.py",
    "content": "import modules.scripts as scripts\r\nimport gradio as gr\r\nimport ast\r\nimport copy\r\n\r\nfrom modules.processing import Processed\r\nfrom modules.shared import cmd_opts\r\n\r\n\r\ndef convertExpr2Expression(expr):\r\n    expr.lineno = 0\r\n    expr.col_offset = 0\r\n    result = ast.Expression(expr.value, lineno=0, col_offset = 0)\r\n\r\n    return result\r\n\r\n\r\ndef exec_with_return(code, module):\r\n    \"\"\"\r\n    like exec() but can return values\r\n    https://stackoverflow.com/a/52361938/5862977\r\n    \"\"\"\r\n    code_ast = ast.parse(code)\r\n\r\n    init_ast = copy.deepcopy(code_ast)\r\n    init_ast.body = code_ast.body[:-1]\r\n\r\n    last_ast = copy.deepcopy(code_ast)\r\n    last_ast.body = code_ast.body[-1:]\r\n\r\n    exec(compile(init_ast, \"<ast>\", \"exec\"), module.__dict__)\r\n    if type(last_ast.body[0]) == ast.Expr:\r\n        return eval(compile(convertExpr2Expression(last_ast.body[0]), \"<ast>\", \"eval\"), module.__dict__)\r\n    else:\r\n        exec(compile(last_ast, \"<ast>\", \"exec\"), module.__dict__)\r\n\r\n\r\nclass Script(scripts.Script):\r\n\r\n    def title(self):\r\n        return \"Custom code\"\r\n\r\n    def show(self, is_img2img):\r\n        return cmd_opts.allow_code\r\n\r\n    def ui(self, is_img2img):\r\n        example = \"\"\"from modules.processing import process_images\r\n\r\np.width = 768\r\np.height = 768\r\np.batch_size = 2\r\np.steps = 10\r\n\r\nreturn process_images(p)\r\n\"\"\"\r\n\r\n\r\n        code = gr.Code(value=example, language=\"python\", label=\"Python code\", elem_id=self.elem_id(\"code\"))\r\n        indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id(\"indent_level\"))\r\n\r\n        return [code, indent_level]\r\n\r\n    def run(self, p, code, indent_level):\r\n        assert cmd_opts.allow_code, '--allow-code option must be enabled'\r\n\r\n        display_result_data = [[], -1, \"\"]\r\n\r\n        def display(imgs, 
s=display_result_data[1], i=display_result_data[2]):\r\n            display_result_data[0] = imgs\r\n            display_result_data[1] = s\r\n            display_result_data[2] = i\r\n\r\n        from types import ModuleType\r\n        module = ModuleType(\"testmodule\")\r\n        module.__dict__.update(globals())\r\n        module.p = p\r\n        module.display = display\r\n\r\n        indent = \" \" * indent_level\r\n        indented = code.replace('\\n', f\"\\n{indent}\")\r\n        body = f\"\"\"def __webuitemp__():\r\n{indent}{indented}\r\n__webuitemp__()\"\"\"\r\n\r\n        result = exec_with_return(body, module)\r\n\r\n        if isinstance(result, Processed):\r\n            return result\r\n\r\n        return Processed(p, *display_result_data)\r\n"
  },
  {
    "path": "scripts/img2imgalt.py",
    "content": "from collections import namedtuple\r\n\r\nimport numpy as np\r\nfrom tqdm import trange\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\n\r\nfrom modules import processing, shared, sd_samplers, sd_samplers_common\r\n\r\nimport torch\r\nimport k_diffusion as K\r\n\r\ndef find_noise_for_image(p, cond, uncond, cfg_scale, steps):\r\n    x = p.init_latent\r\n\r\n    s_in = x.new_ones([x.shape[0]])\r\n    if shared.sd_model.parameterization == \"v\":\r\n        dnw = K.external.CompVisVDenoiser(shared.sd_model)\r\n        skip = 1\r\n    else:\r\n        dnw = K.external.CompVisDenoiser(shared.sd_model)\r\n        skip = 0\r\n    sigmas = dnw.get_sigmas(steps).flip(0)\r\n\r\n    shared.state.sampling_steps = steps\r\n\r\n    for i in trange(1, len(sigmas)):\r\n        shared.state.sampling_step += 1\r\n\r\n        x_in = torch.cat([x] * 2)\r\n        sigma_in = torch.cat([sigmas[i] * s_in] * 2)\r\n        cond_in = torch.cat([uncond, cond])\r\n\r\n        image_conditioning = torch.cat([p.image_conditioning] * 2)\r\n        cond_in = {\"c_concat\": [image_conditioning], \"c_crossattn\": [cond_in]}\r\n\r\n        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]\r\n        t = dnw.sigma_to_t(sigma_in)\r\n\r\n        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)\r\n        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)\r\n\r\n        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale\r\n\r\n        d = (x - denoised) / sigmas[i]\r\n        dt = sigmas[i] - sigmas[i - 1]\r\n\r\n        x = x + d * dt\r\n\r\n        sd_samplers_common.store_latent(x)\r\n\r\n        # This shouldn't be necessary, but solved some VRAM issues\r\n        del x_in, sigma_in, cond_in, c_out, c_in, t,\r\n        del eps, denoised_uncond, denoised_cond, denoised, d, dt\r\n\r\n    shared.state.nextjob()\r\n\r\n    return x / x.std()\r\n\r\n\r\nCached = 
namedtuple(\"Cached\", [\"noise\", \"cfg_scale\", \"steps\", \"latent\", \"original_prompt\", \"original_negative_prompt\", \"sigma_adjustment\"])\r\n\r\n\r\n# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736\r\ndef find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):\r\n    x = p.init_latent\r\n\r\n    s_in = x.new_ones([x.shape[0]])\r\n    if shared.sd_model.parameterization == \"v\":\r\n        dnw = K.external.CompVisVDenoiser(shared.sd_model)\r\n        skip = 1\r\n    else:\r\n        dnw = K.external.CompVisDenoiser(shared.sd_model)\r\n        skip = 0\r\n    sigmas = dnw.get_sigmas(steps).flip(0)\r\n\r\n    shared.state.sampling_steps = steps\r\n\r\n    for i in trange(1, len(sigmas)):\r\n        shared.state.sampling_step += 1\r\n\r\n        x_in = torch.cat([x] * 2)\r\n        sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)\r\n        cond_in = torch.cat([uncond, cond])\r\n\r\n        image_conditioning = torch.cat([p.image_conditioning] * 2)\r\n        cond_in = {\"c_concat\": [image_conditioning], \"c_crossattn\": [cond_in]}\r\n\r\n        c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)[skip:]]\r\n\r\n        if i == 1:\r\n            t = dnw.sigma_to_t(torch.cat([sigmas[i] * s_in] * 2))\r\n        else:\r\n            t = dnw.sigma_to_t(sigma_in)\r\n\r\n        eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)\r\n        denoised_uncond, denoised_cond = (x_in + eps * c_out).chunk(2)\r\n\r\n        denoised = denoised_uncond + (denoised_cond - denoised_uncond) * cfg_scale\r\n\r\n        if i == 1:\r\n            d = (x - denoised) / (2 * sigmas[i])\r\n        else:\r\n            d = (x - denoised) / sigmas[i - 1]\r\n\r\n        dt = sigmas[i] - sigmas[i - 1]\r\n        x = x + d * dt\r\n\r\n        sd_samplers_common.store_latent(x)\r\n\r\n        # This shouldn't be necessary, but solved some VRAM issues\r\n        
del x_in, sigma_in, cond_in, c_out, c_in, t,\r\n        del eps, denoised_uncond, denoised_cond, denoised, d, dt\r\n\r\n    shared.state.nextjob()\r\n\r\n    return x / sigmas[-1]\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def __init__(self):\r\n        self.cache = None\r\n\r\n    def title(self):\r\n        return \"img2img alternative test\"\r\n\r\n    def show(self, is_img2img):\r\n        return is_img2img\r\n\r\n    def ui(self, is_img2img):\r\n        info = gr.Markdown('''\r\n        * `CFG Scale` should be 2 or lower.\r\n        ''')\r\n\r\n        override_sampler = gr.Checkbox(label=\"Override `Sampling method` to Euler?(this method is built for it)\", value=True, elem_id=self.elem_id(\"override_sampler\"))\r\n\r\n        override_prompt = gr.Checkbox(label=\"Override `prompt` to the same value as `original prompt`?(and `negative prompt`)\", value=True, elem_id=self.elem_id(\"override_prompt\"))\r\n        original_prompt = gr.Textbox(label=\"Original prompt\", lines=1, elem_id=self.elem_id(\"original_prompt\"))\r\n        original_negative_prompt = gr.Textbox(label=\"Original negative prompt\", lines=1, elem_id=self.elem_id(\"original_negative_prompt\"))\r\n\r\n        override_steps = gr.Checkbox(label=\"Override `Sampling Steps` to the same value as `Decode steps`?\", value=True, elem_id=self.elem_id(\"override_steps\"))\r\n        st = gr.Slider(label=\"Decode steps\", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id(\"st\"))\r\n\r\n        override_strength = gr.Checkbox(label=\"Override `Denoising strength` to 1?\", value=True, elem_id=self.elem_id(\"override_strength\"))\r\n\r\n        cfg = gr.Slider(label=\"Decode CFG scale\", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id(\"cfg\"))\r\n        randomness = gr.Slider(label=\"Randomness\", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id(\"randomness\"))\r\n        sigma_adjustment = gr.Checkbox(label=\"Sigma adjustment for finding noise 
for image\", value=False, elem_id=self.elem_id(\"sigma_adjustment\"))\r\n\r\n        return [\r\n            info,\r\n            override_sampler,\r\n            override_prompt, original_prompt, original_negative_prompt,\r\n            override_steps, st,\r\n            override_strength,\r\n            cfg, randomness, sigma_adjustment,\r\n        ]\r\n\r\n    def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):\r\n        # Override\r\n        if override_sampler:\r\n            p.sampler_name = \"Euler\"\r\n        if override_prompt:\r\n            p.prompt = original_prompt\r\n            p.negative_prompt = original_negative_prompt\r\n        if override_steps:\r\n            p.steps = st\r\n        if override_strength:\r\n            p.denoising_strength = 1.0\r\n\r\n        def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):\r\n            lat = (p.init_latent.cpu().numpy() * 10).astype(int)\r\n\r\n            same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \\\r\n                                and self.cache.original_prompt == original_prompt \\\r\n                                and self.cache.original_negative_prompt == original_negative_prompt \\\r\n                                and self.cache.sigma_adjustment == sigma_adjustment\r\n            same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100\r\n\r\n            if same_everything:\r\n                rec_noise = self.cache.noise\r\n            else:\r\n                shared.state.job_count += 1\r\n                cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])\r\n                uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])\r\n                if 
sigma_adjustment:\r\n                    rec_noise = find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg, st)\r\n                else:\r\n                    rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)\r\n                self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)\r\n\r\n            rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)\r\n\r\n            combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)\r\n\r\n            sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)\r\n\r\n            sigmas = sampler.model_wrap.get_sigmas(p.steps)\r\n\r\n            noise_dt = combined_noise - (p.init_latent / sigmas[0])\r\n\r\n            p.seed = p.seed + 1\r\n\r\n            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)\r\n\r\n        p.sample = sample_extra\r\n\r\n        p.extra_generation_params[\"Decode prompt\"] = original_prompt\r\n        p.extra_generation_params[\"Decode negative prompt\"] = original_negative_prompt\r\n        p.extra_generation_params[\"Decode CFG scale\"] = cfg\r\n        p.extra_generation_params[\"Decode steps\"] = st\r\n        p.extra_generation_params[\"Randomness\"] = randomness\r\n        p.extra_generation_params[\"Sigma Adjustment\"] = sigma_adjustment\r\n\r\n        processed = processing.process_images(p)\r\n\r\n        return processed\r\n"
  },
  {
    "path": "scripts/loopback.py",
    "content": "import math\r\n\r\nimport gradio as gr\r\nimport modules.scripts as scripts\r\nfrom modules import deepbooru, images, processing, shared\r\nfrom modules.processing import Processed\r\nfrom modules.shared import opts, state\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"Loopback\"\r\n\r\n    def show(self, is_img2img):\r\n        return is_img2img\r\n\r\n    def ui(self, is_img2img):\r\n        loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id(\"loops\"))\r\n        final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id(\"final_denoising_strength\"))\r\n        denoising_curve = gr.Dropdown(label=\"Denoising strength curve\", choices=[\"Aggressive\", \"Linear\", \"Lazy\"], value=\"Linear\")\r\n        append_interrogation = gr.Dropdown(label=\"Append interrogated prompt at each iteration\", choices=[\"None\", \"CLIP\", \"DeepBooru\"], value=\"None\")\r\n\r\n        return [loops, final_denoising_strength, denoising_curve, append_interrogation]\r\n\r\n    def run(self, p, loops, final_denoising_strength, denoising_curve, append_interrogation):\r\n        processing.fix_seed(p)\r\n        batch_count = p.n_iter\r\n        p.extra_generation_params = {\r\n            \"Final denoising strength\": final_denoising_strength,\r\n            \"Denoising curve\": denoising_curve\r\n        }\r\n\r\n        p.batch_size = 1\r\n        p.n_iter = 1\r\n\r\n        info = None\r\n        initial_seed = None\r\n        initial_info = None\r\n        initial_denoising_strength = p.denoising_strength\r\n\r\n        grids = []\r\n        all_images = []\r\n        original_init_image = p.init_images\r\n        original_prompt = p.prompt\r\n        original_inpainting_fill = p.inpainting_fill\r\n        state.job_count = loops * batch_count\r\n\r\n        initial_color_corrections = 
[processing.setup_color_correction(p.init_images[0])]\r\n\r\n        def calculate_denoising_strength(loop):\r\n            strength = initial_denoising_strength\r\n\r\n            if loops == 1:\r\n                return strength\r\n\r\n            progress = loop / (loops - 1)\r\n            if denoising_curve == \"Aggressive\":\r\n                strength = math.sin((progress) * math.pi * 0.5)\r\n            elif denoising_curve == \"Lazy\":\r\n                strength = 1 - math.cos((progress) * math.pi * 0.5)\r\n            else:\r\n                strength = progress\r\n\r\n            change = (final_denoising_strength - initial_denoising_strength) * strength\r\n            return initial_denoising_strength + change\r\n\r\n        history = []\r\n\r\n        for n in range(batch_count):\r\n            # Reset to original init image at the start of each batch\r\n            p.init_images = original_init_image\r\n\r\n            # Reset to original denoising strength\r\n            p.denoising_strength = initial_denoising_strength\r\n\r\n            last_image = None\r\n\r\n            for i in range(loops):\r\n                p.n_iter = 1\r\n                p.batch_size = 1\r\n                p.do_not_save_grid = True\r\n\r\n                if opts.img2img_color_correction:\r\n                    p.color_corrections = initial_color_corrections\r\n\r\n                if append_interrogation != \"None\":\r\n                    p.prompt = f\"{original_prompt}, \" if original_prompt else \"\"\r\n                    if append_interrogation == \"CLIP\":\r\n                        p.prompt += shared.interrogator.interrogate(p.init_images[0])\r\n                    elif append_interrogation == \"DeepBooru\":\r\n                        p.prompt += deepbooru.model.tag(p.init_images[0])\r\n\r\n                state.job = f\"Iteration {i + 1}/{loops}, batch {n + 1}/{batch_count}\"\r\n\r\n                processed = processing.process_images(p)\r\n\r\n                # 
Generation cancelled.\r\n                if state.interrupted or state.stopping_generation:\r\n                    break\r\n\r\n                if initial_seed is None:\r\n                    initial_seed = processed.seed\r\n                    initial_info = processed.info\r\n\r\n                p.seed = processed.seed + 1\r\n                p.denoising_strength = calculate_denoising_strength(i + 1)\r\n\r\n                if state.skipped:\r\n                    break\r\n\r\n                last_image = processed.images[0]\r\n                p.init_images = [last_image]\r\n                p.inpainting_fill = 1 # Set \"masked content\" to \"original\" for next loop.\r\n\r\n                if batch_count == 1:\r\n                    history.append(last_image)\r\n                    all_images.append(last_image)\r\n\r\n            if batch_count > 1 and not state.skipped and not state.interrupted:\r\n                history.append(last_image)\r\n                all_images.append(last_image)\r\n\r\n            p.inpainting_fill = original_inpainting_fill\r\n\r\n            if state.interrupted or state.stopping_generation:\r\n                break\r\n\r\n        if len(history) > 1:\r\n            grid = images.image_grid(history, rows=1)\r\n            if opts.grid_save:\r\n                images.save_image(grid, p.outpath_grids, \"grid\", initial_seed, p.prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename, grid=True, p=p)\r\n\r\n            if opts.return_grid:\r\n                grids.append(grid)\r\n\r\n        all_images = grids + all_images\r\n\r\n        processed = Processed(p, all_images, initial_seed, initial_info)\r\n\r\n        return processed\r\n"
  },
  {
    "path": "scripts/outpainting_mk_2.py",
    "content": "import math\r\n\r\nimport numpy as np\r\nimport skimage\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\nfrom PIL import Image, ImageDraw\r\n\r\nfrom modules import images\r\nfrom modules.processing import Processed, process_images\r\nfrom modules.shared import opts, state\r\n\r\n\r\n# this function is taken from https://github.com/parlance-zz/g-diffuser-bot\r\ndef get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):\r\n    # helper fft routines that keep ortho normalization and auto-shift before and after fft\r\n    def _fft2(data):\r\n        if data.ndim > 2:  # has channels\r\n            out_fft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)\r\n            for c in range(data.shape[2]):\r\n                c_data = data[:, :, c]\r\n                out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm=\"ortho\")\r\n                out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])\r\n        else:  # one channel\r\n            out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)\r\n            out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm=\"ortho\")\r\n            out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])\r\n\r\n        return out_fft\r\n\r\n    def _ifft2(data):\r\n        if data.ndim > 2:  # has channels\r\n            out_ifft = np.zeros((data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128)\r\n            for c in range(data.shape[2]):\r\n                c_data = data[:, :, c]\r\n                out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm=\"ortho\")\r\n                out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])\r\n        else:  # one channel\r\n            out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)\r\n            out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm=\"ortho\")\r\n            out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, 
:])\r\n\r\n        return out_ifft\r\n\r\n    def _get_gaussian_window(width, height, std=3.14, mode=0):\r\n        window_scale_x = float(width / min(width, height))\r\n        window_scale_y = float(height / min(width, height))\r\n\r\n        window = np.zeros((width, height))\r\n        x = (np.arange(width) / width * 2. - 1.) * window_scale_x\r\n        for y in range(height):\r\n            fy = (y / height * 2. - 1.) * window_scale_y\r\n            if mode == 0:\r\n                window[:, y] = np.exp(-(x ** 2 + fy ** 2) * std)\r\n            else:\r\n                window[:, y] = (1 / ((x ** 2 + 1.) * (fy ** 2 + 1.))) ** (std / 3.14)  # hey wait a minute that's not gaussian\r\n\r\n        return window\r\n\r\n    def _get_masked_window_rgb(np_mask_grey, hardness=1.):\r\n        np_mask_rgb = np.zeros((np_mask_grey.shape[0], np_mask_grey.shape[1], 3))\r\n        if hardness != 1.:\r\n            hardened = np_mask_grey[:] ** hardness\r\n        else:\r\n            hardened = np_mask_grey[:]\r\n        for c in range(3):\r\n            np_mask_rgb[:, :, c] = hardened[:]\r\n        return np_mask_rgb\r\n\r\n    width = _np_src_image.shape[0]\r\n    height = _np_src_image.shape[1]\r\n    num_channels = _np_src_image.shape[2]\r\n\r\n    _np_src_image[:] * (1. - np_mask_rgb)\r\n    np_mask_grey = (np.sum(np_mask_rgb, axis=2) / 3.)\r\n    img_mask = np_mask_grey > 1e-6\r\n    ref_mask = np_mask_grey < 1e-3\r\n\r\n    windowed_image = _np_src_image * (1. 
- _get_masked_window_rgb(np_mask_grey))\r\n    windowed_image /= np.max(windowed_image)\r\n    windowed_image += np.average(_np_src_image) * np_mask_rgb  # / (1.-np.average(np_mask_rgb))  # rather than leave the masked area black, we get better results from fft by filling the average unmasked color\r\n\r\n    src_fft = _fft2(windowed_image)  # get feature statistics from masked src img\r\n    src_dist = np.absolute(src_fft)\r\n    src_phase = src_fft / src_dist\r\n\r\n    # create a generator with a static seed to make outpainting deterministic / only follow global seed\r\n    rng = np.random.default_rng(0)\r\n\r\n    noise_window = _get_gaussian_window(width, height, mode=1)  # start with simple gaussian noise\r\n    noise_rgb = rng.random((width, height, num_channels))\r\n    noise_grey = (np.sum(noise_rgb, axis=2) / 3.)\r\n    noise_rgb *= color_variation  # the colorfulness of the starting noise is blended to greyscale with a parameter\r\n    for c in range(num_channels):\r\n        noise_rgb[:, :, c] += (1. - color_variation) * noise_grey\r\n\r\n    noise_fft = _fft2(noise_rgb)\r\n    for c in range(num_channels):\r\n        noise_fft[:, :, c] *= noise_window\r\n    noise_rgb = np.real(_ifft2(noise_fft))\r\n    shaped_noise_fft = _fft2(noise_rgb)\r\n    shaped_noise_fft[:, :, :] = np.absolute(shaped_noise_fft[:, :, :]) ** 2 * (src_dist ** noise_q) * src_phase  # perform the actual shaping\r\n\r\n    brightness_variation = 0.  # color_variation # todo: temporarily tying brightness variation to color variation for now\r\n    contrast_adjusted_np_src = _np_src_image[:] * (brightness_variation + 1.) 
- brightness_variation * 2.\r\n\r\n    # scikit-image is used for histogram matching, very convenient!\r\n    shaped_noise = np.real(_ifft2(shaped_noise_fft))\r\n    shaped_noise -= np.min(shaped_noise)\r\n    shaped_noise /= np.max(shaped_noise)\r\n    shaped_noise[img_mask, :] = skimage.exposure.match_histograms(shaped_noise[img_mask, :] ** 1., contrast_adjusted_np_src[ref_mask, :], channel_axis=1)\r\n    shaped_noise = _np_src_image[:] * (1. - np_mask_rgb) + shaped_noise * np_mask_rgb\r\n\r\n    matched_noise = shaped_noise[:]\r\n\r\n    return np.clip(matched_noise, 0., 1.)\r\n\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"Outpainting mk2\"\r\n\r\n    def show(self, is_img2img):\r\n        return is_img2img\r\n\r\n    def ui(self, is_img2img):\r\n        if not is_img2img:\r\n            return None\r\n\r\n        info = gr.HTML(\"<p style=\\\"margin-bottom:0.75em\\\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>\")\r\n\r\n        pixels = gr.Slider(label=\"Pixels to expand\", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id(\"pixels\"))\r\n        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id(\"mask_blur\"))\r\n        direction = gr.CheckboxGroup(label=\"Outpainting direction\", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id(\"direction\"))\r\n        noise_q = gr.Slider(label=\"Fall-off exponent (lower=higher detail)\", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id(\"noise_q\"))\r\n        color_variation = gr.Slider(label=\"Color variation\", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id(\"color_variation\"))\r\n\r\n        return [info, pixels, mask_blur, direction, noise_q, color_variation]\r\n\r\n    def run(self, p, _, pixels, mask_blur, direction, noise_q, color_variation):\r\n        
initial_seed_and_info = [None, None]\r\n\r\n        process_width = p.width\r\n        process_height = p.height\r\n\r\n        p.inpaint_full_res = False\r\n        p.inpainting_fill = 1\r\n        p.do_not_save_samples = True\r\n        p.do_not_save_grid = True\r\n\r\n        left = pixels if \"left\" in direction else 0\r\n        right = pixels if \"right\" in direction else 0\r\n        up = pixels if \"up\" in direction else 0\r\n        down = pixels if \"down\" in direction else 0\r\n\r\n        if left > 0 or right > 0:\r\n            mask_blur_x = mask_blur\r\n        else:\r\n            mask_blur_x = 0\r\n\r\n        if up > 0 or down > 0:\r\n            mask_blur_y = mask_blur\r\n        else:\r\n            mask_blur_y = 0\r\n\r\n        p.mask_blur_x = mask_blur_x*4\r\n        p.mask_blur_y = mask_blur_y*4\r\n\r\n        init_img = p.init_images[0]\r\n        target_w = math.ceil((init_img.width + left + right) / 64) * 64\r\n        target_h = math.ceil((init_img.height + up + down) / 64) * 64\r\n\r\n        if left > 0:\r\n            left = left * (target_w - init_img.width) // (left + right)\r\n\r\n        if right > 0:\r\n            right = target_w - init_img.width - left\r\n\r\n        if up > 0:\r\n            up = up * (target_h - init_img.height) // (up + down)\r\n\r\n        if down > 0:\r\n            down = target_h - init_img.height - up\r\n\r\n        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):\r\n            is_horiz = is_left or is_right\r\n            is_vert = is_top or is_bottom\r\n            pixels_horiz = expand_pixels if is_horiz else 0\r\n            pixels_vert = expand_pixels if is_vert else 0\r\n\r\n            images_to_process = []\r\n            output_images = []\r\n            for n in range(count):\r\n                res_w = init[n].width + pixels_horiz\r\n                res_h = init[n].height + pixels_vert\r\n                process_res_w = math.ceil(res_w 
/ 64) * 64\r\n                process_res_h = math.ceil(res_h / 64) * 64\r\n\r\n                img = Image.new(\"RGB\", (process_res_w, process_res_h))\r\n                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))\r\n                mask = Image.new(\"RGB\", (process_res_w, process_res_h), \"white\")\r\n                draw = ImageDraw.Draw(mask)\r\n                draw.rectangle((\r\n                    expand_pixels + mask_blur_x if is_left else 0,\r\n                    expand_pixels + mask_blur_y if is_top else 0,\r\n                    mask.width - expand_pixels - mask_blur_x if is_right else res_w,\r\n                    mask.height - expand_pixels - mask_blur_y if is_bottom else res_h,\r\n                ), fill=\"black\")\r\n\r\n                np_image = (np.asarray(img) / 255.0).astype(np.float64)\r\n                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)\r\n                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)\r\n                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode=\"RGB\"))\r\n\r\n                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width\r\n                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height\r\n                p.width = target_width if is_horiz else img.width\r\n                p.height = target_height if is_vert else img.height\r\n\r\n                crop_region = (\r\n                    0 if is_left else output_images[n].width - target_width,\r\n                    0 if is_top else output_images[n].height - target_height,\r\n                    target_width if is_left else output_images[n].width,\r\n                    target_height if is_top else output_images[n].height,\r\n                )\r\n                mask = mask.crop(crop_region)\r\n                p.image_mask = mask\r\n\r\n                
image_to_process = output_images[n].crop(crop_region)\r\n                images_to_process.append(image_to_process)\r\n\r\n            p.init_images = images_to_process\r\n\r\n            latent_mask = Image.new(\"RGB\", (p.width, p.height), \"white\")\r\n            draw = ImageDraw.Draw(latent_mask)\r\n            draw.rectangle((\r\n                expand_pixels + mask_blur_x * 2 if is_left else 0,\r\n                expand_pixels + mask_blur_y * 2 if is_top else 0,\r\n                mask.width - expand_pixels - mask_blur_x * 2 if is_right else res_w,\r\n                mask.height - expand_pixels - mask_blur_y * 2 if is_bottom else res_h,\r\n            ), fill=\"black\")\r\n            p.latent_mask = latent_mask\r\n\r\n            proc = process_images(p)\r\n\r\n            if initial_seed_and_info[0] is None:\r\n                initial_seed_and_info[0] = proc.seed\r\n                initial_seed_and_info[1] = proc.info\r\n\r\n            for n in range(count):\r\n                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))\r\n                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))\r\n\r\n            return output_images\r\n\r\n        batch_count = p.n_iter\r\n        batch_size = p.batch_size\r\n        p.n_iter = 1\r\n        state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))\r\n        all_processed_images = []\r\n\r\n        for i in range(batch_count):\r\n            imgs = [init_img] * batch_size\r\n            state.job = f\"Batch {i + 1} out of {batch_count}\"\r\n\r\n            if left > 0:\r\n                imgs = expand(imgs, batch_size, left, is_left=True)\r\n            if right > 0:\r\n                imgs = expand(imgs, batch_size, right, is_right=True)\r\n            if up > 0:\r\n                imgs = 
expand(imgs, batch_size, up, is_top=True)\r\n            if down > 0:\r\n                imgs = expand(imgs, batch_size, down, is_bottom=True)\r\n\r\n            all_processed_images += imgs\r\n\r\n        all_images = all_processed_images\r\n\r\n        combined_grid_image = images.image_grid(all_processed_images)\r\n        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple\r\n        if opts.return_grid and not unwanted_grid_because_of_img_count:\r\n            all_images = [combined_grid_image] + all_processed_images\r\n\r\n        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])\r\n\r\n        if opts.samples_save:\r\n            for img in all_processed_images:\r\n                images.save_image(img, p.outpath_samples, \"\", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)\r\n\r\n        if opts.grid_save and not unwanted_grid_because_of_img_count:\r\n            images.save_image(combined_grid_image, p.outpath_grids, \"grid\", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)\r\n\r\n        return res\r\n"
  },
  {
    "path": "scripts/poor_mans_outpainting.py",
    "content": "import math\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\nfrom PIL import Image, ImageDraw\r\n\r\nfrom modules import images, devices\r\nfrom modules.processing import Processed, process_images\r\nfrom modules.shared import opts, state\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"Poor man's outpainting\"\r\n\r\n    def show(self, is_img2img):\r\n        return is_img2img\r\n\r\n    def ui(self, is_img2img):\r\n        if not is_img2img:\r\n            return None\r\n\r\n        pixels = gr.Slider(label=\"Pixels to expand\", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id(\"pixels\"))\r\n        mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id(\"mask_blur\"))\r\n        inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type=\"index\", elem_id=self.elem_id(\"inpainting_fill\"))\r\n        direction = gr.CheckboxGroup(label=\"Outpainting direction\", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id(\"direction\"))\r\n\r\n        return [pixels, mask_blur, inpainting_fill, direction]\r\n\r\n    def run(self, p, pixels, mask_blur, inpainting_fill, direction):\r\n        initial_seed = None\r\n        initial_info = None\r\n\r\n        p.mask_blur = mask_blur * 2\r\n        p.inpainting_fill = inpainting_fill\r\n        p.inpaint_full_res = False\r\n\r\n        left = pixels if \"left\" in direction else 0\r\n        right = pixels if \"right\" in direction else 0\r\n        up = pixels if \"up\" in direction else 0\r\n        down = pixels if \"down\" in direction else 0\r\n\r\n        init_img = p.init_images[0]\r\n        target_w = math.ceil((init_img.width + left + right) / 64) * 64\r\n        target_h = math.ceil((init_img.height + up + down) / 64) * 64\r\n\r\n        if left > 0:\r\n          
  left = left * (target_w - init_img.width) // (left + right)\r\n        if right > 0:\r\n            right = target_w - init_img.width - left\r\n\r\n        if up > 0:\r\n            up = up * (target_h - init_img.height) // (up + down)\r\n\r\n        if down > 0:\r\n            down = target_h - init_img.height - up\r\n\r\n        img = Image.new(\"RGB\", (target_w, target_h))\r\n        img.paste(init_img, (left, up))\r\n\r\n        mask = Image.new(\"L\", (img.width, img.height), \"white\")\r\n        draw = ImageDraw.Draw(mask)\r\n        draw.rectangle((\r\n            left + (mask_blur * 2 if left > 0 else 0),\r\n            up + (mask_blur * 2 if up > 0 else 0),\r\n            mask.width - right - (mask_blur * 2 if right > 0 else 0),\r\n            mask.height - down - (mask_blur * 2 if down > 0 else 0)\r\n        ), fill=\"black\")\r\n\r\n        latent_mask = Image.new(\"L\", (img.width, img.height), \"white\")\r\n        latent_draw = ImageDraw.Draw(latent_mask)\r\n        latent_draw.rectangle((\r\n             left + (mask_blur//2 if left > 0 else 0),\r\n             up + (mask_blur//2 if up > 0 else 0),\r\n             mask.width - right - (mask_blur//2 if right > 0 else 0),\r\n             mask.height - down - (mask_blur//2 if down > 0 else 0)\r\n        ), fill=\"black\")\r\n\r\n        devices.torch_gc()\r\n\r\n        grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)\r\n        grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)\r\n        grid_latent_mask = images.split_grid(latent_mask, tile_w=p.width, tile_h=p.height, overlap=pixels)\r\n\r\n        p.n_iter = 1\r\n        p.batch_size = 1\r\n        p.do_not_save_grid = True\r\n        p.do_not_save_samples = True\r\n\r\n        work = []\r\n        work_mask = []\r\n        work_latent_mask = []\r\n        work_results = []\r\n\r\n        for (y, h, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, 
grid_latent_mask.tiles):\r\n            for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):\r\n                x, w = tiledata[0:2]\r\n\r\n                if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:\r\n                    continue\r\n\r\n                work.append(tiledata[2])\r\n                work_mask.append(tiledata_mask[2])\r\n                work_latent_mask.append(tiledata_latent_mask[2])\r\n\r\n        batch_count = len(work)\r\n        print(f\"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.\")\r\n\r\n        state.job_count = batch_count\r\n\r\n        for i in range(batch_count):\r\n            p.init_images = [work[i]]\r\n            p.image_mask = work_mask[i]\r\n            p.latent_mask = work_latent_mask[i]\r\n\r\n            state.job = f\"Batch {i + 1} out of {batch_count}\"\r\n            processed = process_images(p)\r\n\r\n            if initial_seed is None:\r\n                initial_seed = processed.seed\r\n                initial_info = processed.info\r\n\r\n            p.seed = processed.seed + 1\r\n            work_results += processed.images\r\n\r\n\r\n        image_index = 0\r\n        for y, h, row in grid.tiles:\r\n            for tiledata in row:\r\n                x, w = tiledata[0:2]\r\n\r\n                if x >= left and x+w <= img.width - right and y >= up and y+h <= img.height - down:\r\n                    continue\r\n\r\n                tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new(\"RGB\", (p.width, p.height))\r\n                image_index += 1\r\n\r\n        combined_image = images.combine_grid(grid)\r\n\r\n        if opts.samples_save:\r\n            images.save_image(combined_image, p.outpath_samples, \"\", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p)\r\n\r\n        processed = Processed(p, 
[combined_image], initial_seed, initial_info)\r\n\r\n        return processed\r\n\r\n"
  },
  {
    "path": "scripts/postprocessing_codeformer.py",
    "content": "from PIL import Image\r\nimport numpy as np\r\n\r\nfrom modules import scripts_postprocessing, codeformer_model, ui_components\r\nimport gradio as gr\r\n\r\n\r\nclass ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"CodeFormer\"\r\n    order = 3000\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"CodeFormer\") as enable:\r\n            with gr.Row():\r\n                codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label=\"Visibility\", value=1.0, elem_id=\"extras_codeformer_visibility\")\r\n                codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label=\"Weight (0 = maximum effect, 1 = minimum effect)\", value=0, elem_id=\"extras_codeformer_weight\")\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"codeformer_visibility\": codeformer_visibility,\r\n            \"codeformer_weight\": codeformer_weight,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight):\r\n        if codeformer_visibility == 0 or not enable:\r\n            return\r\n\r\n        restored_img = codeformer_model.codeformer.restore(np.array(pp.image.convert(\"RGB\"), dtype=np.uint8), w=codeformer_weight)\r\n        res = Image.fromarray(restored_img)\r\n\r\n        if codeformer_visibility < 1.0:\r\n            res = Image.blend(pp.image, res, codeformer_visibility)\r\n\r\n        pp.image = res\r\n        pp.info[\"CodeFormer visibility\"] = round(codeformer_visibility, 3)\r\n        pp.info[\"CodeFormer weight\"] = round(codeformer_weight, 3)\r\n"
  },
  {
    "path": "scripts/postprocessing_gfpgan.py",
    "content": "from PIL import Image\r\nimport numpy as np\r\n\r\nfrom modules import scripts_postprocessing, gfpgan_model, ui_components\r\nimport gradio as gr\r\n\r\n\r\nclass ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"GFPGAN\"\r\n    order = 2000\r\n\r\n    def ui(self):\r\n        with ui_components.InputAccordion(False, label=\"GFPGAN\") as enable:\r\n            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label=\"Visibility\", value=1.0, elem_id=\"extras_gfpgan_visibility\")\r\n\r\n        return {\r\n            \"enable\": enable,\r\n            \"gfpgan_visibility\": gfpgan_visibility,\r\n        }\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility):\r\n        if gfpgan_visibility == 0 or not enable:\r\n            return\r\n\r\n        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image.convert(\"RGB\"), dtype=np.uint8))\r\n        res = Image.fromarray(restored_img)\r\n\r\n        if gfpgan_visibility < 1.0:\r\n            res = Image.blend(pp.image, res, gfpgan_visibility)\r\n\r\n        pp.image = res\r\n        pp.info[\"GFPGAN visibility\"] = round(gfpgan_visibility, 3)\r\n"
  },
  {
    "path": "scripts/postprocessing_upscale.py",
    "content": "import re\r\n\r\nfrom PIL import Image\r\nimport numpy as np\r\n\r\nfrom modules import scripts_postprocessing, shared\r\nimport gradio as gr\r\n\r\nfrom modules.ui_components import FormRow, ToolButton, InputAccordion\r\nfrom modules.ui import switch_values_symbol\r\n\r\nupscale_cache = {}\r\n\r\n\r\ndef limit_size_by_one_dimention(w, h, limit):\r\n    if h > w and h > limit:\r\n        w = limit * w // h\r\n        h = limit\r\n    elif w > limit:\r\n        h = limit * h // w\r\n        w = limit\r\n\r\n    return int(w), int(h)\r\n\r\n\r\nclass ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):\r\n    name = \"Upscale\"\r\n    order = 1000\r\n\r\n    def ui(self):\r\n        selected_tab = gr.Number(value=0, visible=False)\r\n\r\n        with InputAccordion(True, label=\"Upscale\", elem_id=\"extras_upscale\") as upscale_enabled:\r\n            with FormRow():\r\n                extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id=\"extras_upscaler_1\", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)\r\n\r\n            with FormRow():\r\n                extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id=\"extras_upscaler_2\", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)\r\n                extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label=\"Upscaler 2 visibility\", value=0.0, elem_id=\"extras_upscaler_2_visibility\")\r\n\r\n            with FormRow():\r\n                with gr.Tabs(elem_id=\"extras_resize_mode\"):\r\n                    with gr.TabItem('Scale by', elem_id=\"extras_scale_by_tab\") as tab_scale_by:\r\n                        with gr.Row():\r\n                            with gr.Column(scale=4):\r\n                                upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label=\"Resize\", value=4, elem_id=\"extras_upscaling_resize\")\r\n                            
with gr.Column(scale=1, min_width=160):\r\n                                max_side_length = gr.Number(label=\"Max side length\", value=0, elem_id=\"extras_upscale_max_side_length\", tooltip=\"If any of two sides of the image ends up larger than specified, will downscale it to fit. 0 = no limit.\", min_width=160, step=8, minimum=0)\r\n\r\n                    with gr.TabItem('Scale to', elem_id=\"extras_scale_to_tab\") as tab_scale_to:\r\n                        with FormRow():\r\n                            with gr.Column(elem_id=\"upscaling_column_size\", scale=4):\r\n                                upscaling_resize_w = gr.Slider(minimum=64, maximum=8192, step=8, label=\"Width\", value=512, elem_id=\"extras_upscaling_resize_w\")\r\n                                upscaling_resize_h = gr.Slider(minimum=64, maximum=8192, step=8, label=\"Height\", value=512, elem_id=\"extras_upscaling_resize_h\")\r\n                            with gr.Column(elem_id=\"upscaling_dimensions_row\", scale=1, elem_classes=\"dimensions-tools\"):\r\n                                upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id=\"upscaling_res_switch_btn\", tooltip=\"Switch width/height\")\r\n                                upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id=\"extras_upscaling_crop\")\r\n\r\n        def on_selected_upscale_method(upscale_method):\r\n            if not shared.opts.set_scale_by_when_changing_upscaler:\r\n                return gr.update()\r\n\r\n            match = re.search(r'(\\d)[xX]|[xX](\\d)', upscale_method)\r\n            if not match:\r\n                return gr.update()\r\n\r\n            return gr.update(value=int(match.group(1) or match.group(2)))\r\n\r\n        upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)\r\n        tab_scale_by.select(fn=lambda: 0, inputs=[], 
outputs=[selected_tab])\r\n        tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])\r\n\r\n        extras_upscaler_1.change(on_selected_upscale_method, inputs=[extras_upscaler_1], outputs=[upscaling_resize], show_progress=\"hidden\")\r\n\r\n        return {\r\n            \"upscale_enabled\": upscale_enabled,\r\n            \"upscale_mode\": selected_tab,\r\n            \"upscale_by\": upscaling_resize,\r\n            \"max_side_length\": max_side_length,\r\n            \"upscale_to_width\": upscaling_resize_w,\r\n            \"upscale_to_height\": upscaling_resize_h,\r\n            \"upscale_crop\": upscaling_crop,\r\n            \"upscaler_1_name\": extras_upscaler_1,\r\n            \"upscaler_2_name\": extras_upscaler_2,\r\n            \"upscaler_2_visibility\": extras_upscaler_2_visibility,\r\n        }\r\n\r\n    def upscale(self, image, info, upscaler, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop):\r\n        if upscale_mode == 1:\r\n            upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)\r\n            info[\"Postprocess upscale to\"] = f\"{upscale_to_width}x{upscale_to_height}\"\r\n        else:\r\n            info[\"Postprocess upscale by\"] = upscale_by\r\n            if max_side_length != 0 and max(*image.size)*upscale_by > max_side_length:\r\n                upscale_mode = 1\r\n                upscale_crop = False\r\n                upscale_to_width, upscale_to_height = limit_size_by_one_dimention(image.width*upscale_by, image.height*upscale_by, max_side_length)\r\n                upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)\r\n                info[\"Max side length\"] = max_side_length\r\n\r\n        cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by,  upscale_to_width, upscale_to_height, upscale_crop)\r\n        cached_image = upscale_cache.pop(cache_key, None)\r\n\r\n     
   if cached_image is not None:\r\n            image = cached_image\r\n        else:\r\n            image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)\r\n\r\n        upscale_cache[cache_key] = image\r\n        if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:\r\n            upscale_cache.pop(next(iter(upscale_cache), None), None)\r\n\r\n        if upscale_mode == 1 and upscale_crop:\r\n            cropped = Image.new(\"RGB\", (upscale_to_width, upscale_to_height))\r\n            cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))\r\n            image = cropped\r\n            info[\"Postprocess crop to\"] = f\"{image.width}x{image.height}\"\r\n\r\n        return image\r\n\r\n    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, max_side_length=0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):\r\n        if upscale_mode == 1:\r\n            pp.shared.target_width = upscale_to_width\r\n            pp.shared.target_height = upscale_to_height\r\n        else:\r\n            pp.shared.target_width = int(pp.image.width * upscale_by)\r\n            pp.shared.target_height = int(pp.image.height * upscale_by)\r\n\r\n            pp.shared.target_width, pp.shared.target_height = limit_size_by_one_dimention(pp.shared.target_width, pp.shared.target_height, max_side_length)\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_enabled=True, upscale_mode=1, upscale_by=2.0, max_side_length=0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):\r\n        if not upscale_enabled:\r\n            return\r\n\r\n        upscaler_1_name = upscaler_1_name\r\n        if upscaler_1_name == \"None\":\r\n           
 upscaler_1_name = None\r\n\r\n        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)\r\n        assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'\r\n\r\n        if not upscaler1:\r\n            return\r\n\r\n        upscaler_2_name = upscaler_2_name\r\n        if upscaler_2_name == \"None\":\r\n            upscaler_2_name = None\r\n\r\n        upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != \"None\"]), None)\r\n        assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'\r\n\r\n        upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop)\r\n        pp.info[\"Postprocess upscaler\"] = upscaler1.name\r\n\r\n        if upscaler2 and upscaler_2_visibility > 0:\r\n            second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, max_side_length, upscale_to_width, upscale_to_height, upscale_crop)\r\n            if upscaled_image.mode != second_upscale.mode:\r\n                second_upscale = second_upscale.convert(upscaled_image.mode)\r\n            upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)\r\n\r\n            pp.info[\"Postprocess upscaler 2\"] = upscaler2.name\r\n\r\n        pp.image = upscaled_image\r\n\r\n    def image_changed(self):\r\n        upscale_cache.clear()\r\n\r\n\r\nclass ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale):\r\n    name = \"Simple Upscale\"\r\n    order = 900\r\n\r\n    def ui(self):\r\n        with FormRow():\r\n            upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)\r\n            upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label=\"Upscale by\", value=2)\r\n\r\n        
return {\r\n            \"upscale_by\": upscale_by,\r\n            \"upscaler_name\": upscaler_name,\r\n        }\r\n\r\n    def process_firstpass(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):\r\n        pp.shared.target_width = int(pp.image.width * upscale_by)\r\n        pp.shared.target_height = int(pp.image.height * upscale_by)\r\n\r\n    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None):\r\n        if upscaler_name is None or upscaler_name == \"None\":\r\n            return\r\n\r\n        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None)\r\n        assert upscaler1, f'could not find upscaler named {upscaler_name}'\r\n\r\n        pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, 0, False)\r\n        pp.info[\"Postprocess upscaler\"] = upscaler1.name\r\n"
  },
  {
    "path": "scripts/prompt_matrix.py",
    "content": "import math\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\n\r\nfrom modules import images\r\nfrom modules.processing import process_images\r\nfrom modules.shared import opts, state\r\nimport modules.sd_samplers\r\n\r\n\r\ndef draw_xy_grid(xs, ys, x_label, y_label, cell):\r\n    res = []\r\n\r\n    ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]\r\n    hor_texts = [[images.GridAnnotation(x_label(x))] for x in xs]\r\n\r\n    first_processed = None\r\n\r\n    state.job_count = len(xs) * len(ys)\r\n\r\n    for iy, y in enumerate(ys):\r\n        for ix, x in enumerate(xs):\r\n            state.job = f\"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}\"\r\n\r\n            processed = cell(x, y)\r\n            if first_processed is None:\r\n                first_processed = processed\r\n\r\n            res.append(processed.images[0])\r\n\r\n    grid = images.image_grid(res, rows=len(ys))\r\n    grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)\r\n\r\n    first_processed.images = [grid]\r\n\r\n    return first_processed\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"Prompt matrix\"\r\n\r\n    def ui(self, is_img2img):\r\n        gr.HTML('<br />')\r\n        with gr.Row():\r\n            with gr.Column():\r\n                put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id(\"put_at_start\"))\r\n                different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id(\"different_seeds\"))\r\n            with gr.Column():\r\n                prompt_type = gr.Radio([\"positive\", \"negative\"], label=\"Select prompt\", elem_id=self.elem_id(\"prompt_type\"), value=\"positive\")\r\n                variations_delimiter = gr.Radio([\"comma\", \"space\"], label=\"Select joining char\", elem_id=self.elem_id(\"variations_delimiter\"), 
value=\"comma\")\r\n            with gr.Column():\r\n                margin_size = gr.Slider(label=\"Grid margins (px)\", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id(\"margin_size\"))\r\n\r\n        return [put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size]\r\n\r\n    def run(self, p, put_at_start, different_seeds, prompt_type, variations_delimiter, margin_size):\r\n        modules.processing.fix_seed(p)\r\n        # Raise error if promp type is not positive or negative\r\n        if prompt_type not in [\"positive\", \"negative\"]:\r\n            raise ValueError(f\"Unknown prompt type {prompt_type}\")\r\n        # Raise error if variations delimiter is not comma or space\r\n        if variations_delimiter not in [\"comma\", \"space\"]:\r\n            raise ValueError(f\"Unknown variations delimiter {variations_delimiter}\")\r\n\r\n        prompt = p.prompt if prompt_type == \"positive\" else p.negative_prompt\r\n        original_prompt = prompt[0] if type(prompt) == list else prompt\r\n        positive_prompt = p.prompt[0] if type(p.prompt) == list else p.prompt\r\n\r\n        delimiter = \", \" if variations_delimiter == \"comma\" else \" \"\r\n\r\n        all_prompts = []\r\n        prompt_matrix_parts = original_prompt.split(\"|\")\r\n        combination_count = 2 ** (len(prompt_matrix_parts) - 1)\r\n        for combination_num in range(combination_count):\r\n            selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]\r\n\r\n            if put_at_start:\r\n                selected_prompts = selected_prompts + [prompt_matrix_parts[0]]\r\n            else:\r\n                selected_prompts = [prompt_matrix_parts[0]] + selected_prompts\r\n\r\n            all_prompts.append(delimiter.join(selected_prompts))\r\n\r\n        p.n_iter = math.ceil(len(all_prompts) / p.batch_size)\r\n        p.do_not_save_grid = True\r\n\r\n        print(f\"Prompt 
matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.\")\r\n\r\n        if prompt_type == \"positive\":\r\n            p.prompt = all_prompts\r\n        else:\r\n            p.negative_prompt = all_prompts\r\n        p.seed = [p.seed + (i if different_seeds else 0) for i in range(len(all_prompts))]\r\n        p.prompt_for_display = positive_prompt\r\n        processed = process_images(p)\r\n\r\n        grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))\r\n        grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)\r\n        processed.images.insert(0, grid)\r\n        processed.index_of_first_image = 1\r\n        processed.infotexts.insert(0, processed.infotexts[0])\r\n\r\n        if opts.grid_save:\r\n            images.save_image(processed.images[0], p.outpath_grids, \"prompt_matrix\", extension=opts.grid_format, prompt=original_prompt, seed=processed.seed, grid=True, p=p)\r\n\r\n        return processed\r\n"
  },
  {
    "path": "scripts/prompts_from_file.py",
    "content": "import copy\r\nimport random\r\nimport shlex\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\n\r\nfrom modules import sd_samplers, errors, sd_models\r\nfrom modules.processing import Processed, process_images\r\nfrom modules.shared import state\r\n\r\n\r\ndef process_model_tag(tag):\r\n    info = sd_models.get_closet_checkpoint_match(tag)\r\n    assert info is not None, f'Unknown checkpoint: {tag}'\r\n    return info.name\r\n\r\n\r\ndef process_string_tag(tag):\r\n    return tag\r\n\r\n\r\ndef process_int_tag(tag):\r\n    return int(tag)\r\n\r\n\r\ndef process_float_tag(tag):\r\n    return float(tag)\r\n\r\n\r\ndef process_boolean_tag(tag):\r\n    return True if (tag == \"true\") else False\r\n\r\n\r\nprompt_tags = {\r\n    \"sd_model\": process_model_tag,\r\n    \"outpath_samples\": process_string_tag,\r\n    \"outpath_grids\": process_string_tag,\r\n    \"prompt_for_display\": process_string_tag,\r\n    \"prompt\": process_string_tag,\r\n    \"negative_prompt\": process_string_tag,\r\n    \"styles\": process_string_tag,\r\n    \"seed\": process_int_tag,\r\n    \"subseed_strength\": process_float_tag,\r\n    \"subseed\": process_int_tag,\r\n    \"seed_resize_from_h\": process_int_tag,\r\n    \"seed_resize_from_w\": process_int_tag,\r\n    \"sampler_index\": process_int_tag,\r\n    \"sampler_name\": process_string_tag,\r\n    \"batch_size\": process_int_tag,\r\n    \"n_iter\": process_int_tag,\r\n    \"steps\": process_int_tag,\r\n    \"cfg_scale\": process_float_tag,\r\n    \"width\": process_int_tag,\r\n    \"height\": process_int_tag,\r\n    \"restore_faces\": process_boolean_tag,\r\n    \"tiling\": process_boolean_tag,\r\n    \"do_not_save_samples\": process_boolean_tag,\r\n    \"do_not_save_grid\": process_boolean_tag\r\n}\r\n\r\n\r\ndef cmdargs(line):\r\n    args = shlex.split(line)\r\n    pos = 0\r\n    res = {}\r\n\r\n    while pos < len(args):\r\n        arg = args[pos]\r\n\r\n        assert arg.startswith(\"--\"), f'must 
start with \"--\": {arg}'\r\n        assert pos+1 < len(args), f'missing argument for command line option {arg}'\r\n\r\n        tag = arg[2:]\r\n\r\n        if tag == \"prompt\" or tag == \"negative_prompt\":\r\n            pos += 1\r\n            prompt = args[pos]\r\n            pos += 1\r\n            while pos < len(args) and not args[pos].startswith(\"--\"):\r\n                prompt += \" \"\r\n                prompt += args[pos]\r\n                pos += 1\r\n            res[tag] = prompt\r\n            continue\r\n\r\n\r\n        func = prompt_tags.get(tag, None)\r\n        assert func, f'unknown commandline option: {arg}'\r\n\r\n        val = args[pos+1]\r\n        if tag == \"sampler_name\":\r\n            val = sd_samplers.samplers_map.get(val.lower(), None)\r\n\r\n        res[tag] = func(val)\r\n\r\n        pos += 2\r\n\r\n    return res\r\n\r\n\r\ndef load_prompt_file(file):\r\n    if file is None:\r\n        return None, gr.update(), gr.update(lines=7)\r\n    else:\r\n        lines = [x.strip() for x in file.decode('utf8', errors='ignore').split(\"\\n\")]\r\n        return None, \"\\n\".join(lines), gr.update(lines=7)\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"Prompts from file or textbox\"\r\n\r\n    def ui(self, is_img2img):\r\n        checkbox_iterate = gr.Checkbox(label=\"Iterate seed every line\", value=False, elem_id=self.elem_id(\"checkbox_iterate\"))\r\n        checkbox_iterate_batch = gr.Checkbox(label=\"Use same random seed for all lines\", value=False, elem_id=self.elem_id(\"checkbox_iterate_batch\"))\r\n        prompt_position = gr.Radio([\"start\", \"end\"], label=\"Insert prompts at the\", elem_id=self.elem_id(\"prompt_position\"), value=\"start\")\r\n\r\n        prompt_txt = gr.Textbox(label=\"List of prompt inputs\", lines=1, elem_id=self.elem_id(\"prompt_txt\"))\r\n        file = gr.File(label=\"Upload prompt inputs\", type='binary', elem_id=self.elem_id(\"file\"))\r\n\r\n        
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)\r\n\r\n        # We start at one line. When the text changes, we jump to seven lines, or two lines if no \\n.\r\n        # We don't shrink back to 1, because that causes the control to ignore [enter], and it may\r\n        # be unclear to the user that shift-enter is needed.\r\n        prompt_txt.change(lambda tb: gr.update(lines=7) if (\"\\n\" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)\r\n        return [checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt]\r\n\r\n    def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_position, prompt_txt: str):\r\n        lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x]\r\n\r\n        p.do_not_save_grid = True\r\n\r\n        job_count = 0\r\n        jobs = []\r\n\r\n        for line in lines:\r\n            if \"--\" in line:\r\n                try:\r\n                    args = cmdargs(line)\r\n                except Exception:\r\n                    errors.report(f\"Error parsing line {line} as commandline\", exc_info=True)\r\n                    args = {\"prompt\": line}\r\n            else:\r\n                args = {\"prompt\": line}\r\n\r\n            job_count += args.get(\"n_iter\", p.n_iter)\r\n\r\n            jobs.append(args)\r\n\r\n        print(f\"Will process {len(lines)} lines in {job_count} jobs.\")\r\n        if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:\r\n            p.seed = int(random.randrange(4294967294))\r\n\r\n        state.job_count = job_count\r\n\r\n        images = []\r\n        all_prompts = []\r\n        infotexts = []\r\n        for args in jobs:\r\n            state.job = f\"{state.job_no + 1} out of {state.job_count}\"\r\n\r\n            copy_p = copy.copy(p)\r\n            for k, v in args.items():\r\n                if k == \"sd_model\":\r\n                    
copy_p.override_settings['sd_model_checkpoint'] = v\r\n                else:\r\n                    setattr(copy_p, k, v)\r\n\r\n            if args.get(\"prompt\") and p.prompt:\r\n                if prompt_position == \"start\":\r\n                    copy_p.prompt = args.get(\"prompt\") + \" \" + p.prompt\r\n                else:\r\n                    copy_p.prompt = p.prompt + \" \" + args.get(\"prompt\")\r\n\r\n            if args.get(\"negative_prompt\") and p.negative_prompt:\r\n                if prompt_position == \"start\":\r\n                    copy_p.negative_prompt = args.get(\"negative_prompt\") + \" \" + p.negative_prompt\r\n                else:\r\n                    copy_p.negative_prompt = p.negative_prompt + \" \" + args.get(\"negative_prompt\")\r\n\r\n            proc = process_images(copy_p)\r\n            images += proc.images\r\n\r\n            if checkbox_iterate:\r\n                p.seed = p.seed + (p.batch_size * p.n_iter)\r\n            all_prompts += proc.all_prompts\r\n            infotexts += proc.infotexts\r\n\r\n        return Processed(p, images, p.seed, \"\", all_prompts=all_prompts, infotexts=infotexts)\r\n"
  },
  {
    "path": "scripts/sd_upscale.py",
    "content": "import math\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\nfrom PIL import Image\r\n\r\nfrom modules import processing, shared, images, devices\r\nfrom modules.processing import Processed\r\nfrom modules.shared import opts, state\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"SD upscale\"\r\n\r\n    def show(self, is_img2img):\r\n        return is_img2img\r\n\r\n    def ui(self, is_img2img):\r\n        info = gr.HTML(\"<p style=\\\"margin-bottom:0.75em\\\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>\")\r\n        overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id(\"overlap\"))\r\n        scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id(\"scale_factor\"))\r\n        upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type=\"index\", elem_id=self.elem_id(\"upscaler_index\"))\r\n\r\n        return [info, overlap, upscaler_index, scale_factor]\r\n\r\n    def run(self, p, _, overlap, upscaler_index, scale_factor):\r\n        if isinstance(upscaler_index, str):\r\n            upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())\r\n        processing.fix_seed(p)\r\n        upscaler = shared.sd_upscalers[upscaler_index]\r\n\r\n        p.extra_generation_params[\"SD upscale overlap\"] = overlap\r\n        p.extra_generation_params[\"SD upscale upscaler\"] = upscaler.name\r\n\r\n        initial_info = None\r\n        seed = p.seed\r\n\r\n        init_img = p.init_images[0]\r\n        init_img = images.flatten(init_img, opts.img2img_background_color)\r\n\r\n        if upscaler.name != \"None\":\r\n            img = upscaler.scaler.upscale(init_img, scale_factor, upscaler.data_path)\r\n        else:\r\n            
img = init_img\r\n\r\n        devices.torch_gc()\r\n\r\n        grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=overlap)\r\n\r\n        batch_size = p.batch_size\r\n        upscale_count = p.n_iter\r\n        p.n_iter = 1\r\n        p.do_not_save_grid = True\r\n        p.do_not_save_samples = True\r\n\r\n        work = []\r\n\r\n        for _y, _h, row in grid.tiles:\r\n            for tiledata in row:\r\n                work.append(tiledata[2])\r\n\r\n        batch_count = math.ceil(len(work) / batch_size)\r\n        state.job_count = batch_count * upscale_count\r\n\r\n        print(f\"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.\")\r\n\r\n        result_images = []\r\n        for n in range(upscale_count):\r\n            start_seed = seed + n\r\n            p.seed = start_seed\r\n\r\n            work_results = []\r\n            for i in range(batch_count):\r\n                p.batch_size = batch_size\r\n                p.init_images = work[i * batch_size:(i + 1) * batch_size]\r\n\r\n                state.job = f\"Batch {i + 1 + n * batch_count} out of {state.job_count}\"\r\n                processed = processing.process_images(p)\r\n\r\n                if initial_info is None:\r\n                    initial_info = processed.info\r\n\r\n                p.seed = processed.seed + 1\r\n                work_results += processed.images\r\n\r\n            image_index = 0\r\n            for _y, _h, row in grid.tiles:\r\n                for tiledata in row:\r\n                    tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new(\"RGB\", (p.width, p.height))\r\n                    image_index += 1\r\n\r\n            combined_image = images.combine_grid(grid)\r\n            result_images.append(combined_image)\r\n\r\n            if opts.samples_save:\r\n                
images.save_image(combined_image, p.outpath_samples, \"\", start_seed, p.prompt, opts.samples_format, info=initial_info, p=p)\r\n\r\n        processed = Processed(p, result_images, seed, initial_info)\r\n\r\n        return processed\r\n"
  },
  {
    "path": "scripts/xyz_grid.py",
    "content": "from collections import namedtuple\r\nfrom copy import copy\r\nfrom itertools import permutations, chain\r\nimport random\r\nimport csv\r\nimport os.path\r\nfrom io import StringIO\r\nfrom PIL import Image\r\nimport numpy as np\r\n\r\nimport modules.scripts as scripts\r\nimport gradio as gr\r\n\r\nfrom modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors\r\nfrom modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img\r\nfrom modules.shared import opts, state\r\nimport modules.shared as shared\r\nimport modules.sd_samplers\r\nimport modules.sd_models\r\nimport modules.sd_vae\r\nimport re\r\n\r\nfrom modules.ui_components import ToolButton\r\n\r\nfill_values_symbol = \"\\U0001f4d2\"  # 📒\r\n\r\nAxisInfo = namedtuple('AxisInfo', ['axis', 'values'])\r\n\r\n\r\ndef apply_field(field):\r\n    def fun(p, x, xs):\r\n        setattr(p, field, x)\r\n\r\n    return fun\r\n\r\n\r\ndef apply_prompt(p, x, xs):\r\n    if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:\r\n        raise RuntimeError(f\"Prompt S/R did not find {xs[0]} in prompt or negative prompt.\")\r\n\r\n    p.prompt = p.prompt.replace(xs[0], x)\r\n    p.negative_prompt = p.negative_prompt.replace(xs[0], x)\r\n\r\n\r\ndef apply_order(p, x, xs):\r\n    token_order = []\r\n\r\n    # Initially grab the tokens from the prompt, so they can be replaced in order of earliest seen\r\n    for token in x:\r\n        token_order.append((p.prompt.find(token), token))\r\n\r\n    token_order.sort(key=lambda t: t[0])\r\n\r\n    prompt_parts = []\r\n\r\n    # Split the prompt up, taking out the tokens\r\n    for _, token in token_order:\r\n        n = p.prompt.find(token)\r\n        prompt_parts.append(p.prompt[0:n])\r\n        p.prompt = p.prompt[n + len(token):]\r\n\r\n    # Rebuild the prompt with the tokens in the order we want\r\n    prompt_tmp = \"\"\r\n    for idx, part in enumerate(prompt_parts):\r\n        prompt_tmp += part\r\n   
     prompt_tmp += x[idx]\r\n    p.prompt = prompt_tmp + p.prompt\r\n\r\n\r\ndef confirm_samplers(p, xs):\r\n    for x in xs:\r\n        if x.lower() not in sd_samplers.samplers_map:\r\n            raise RuntimeError(f\"Unknown sampler: {x}\")\r\n\r\n\r\ndef apply_checkpoint(p, x, xs):\r\n    info = modules.sd_models.get_closet_checkpoint_match(x)\r\n    if info is None:\r\n        raise RuntimeError(f\"Unknown checkpoint: {x}\")\r\n    p.override_settings['sd_model_checkpoint'] = info.name\r\n\r\n\r\ndef confirm_checkpoints(p, xs):\r\n    for x in xs:\r\n        if modules.sd_models.get_closet_checkpoint_match(x) is None:\r\n            raise RuntimeError(f\"Unknown checkpoint: {x}\")\r\n\r\n\r\ndef confirm_checkpoints_or_none(p, xs):\r\n    for x in xs:\r\n        if x in (None, \"\", \"None\", \"none\"):\r\n            continue\r\n\r\n        if modules.sd_models.get_closet_checkpoint_match(x) is None:\r\n            raise RuntimeError(f\"Unknown checkpoint: {x}\")\r\n\r\n\r\ndef confirm_range(min_val, max_val, axis_label):\r\n    \"\"\"Generates a AxisOption.confirm() function that checks all values are within the specified range.\"\"\"\r\n\r\n    def confirm_range_fun(p, xs):\r\n        for x in xs:\r\n            if not (max_val >= x >= min_val):\r\n                raise ValueError(f'{axis_label} value \"{x}\" out of range [{min_val}, {max_val}]')\r\n\r\n    return confirm_range_fun\r\n\r\n\r\ndef apply_size(p, x: str, xs) -> None:\r\n    try:\r\n        width, _, height = x.partition('x')\r\n        width = int(width.strip())\r\n        height = int(height.strip())\r\n        p.width = width\r\n        p.height = height\r\n    except ValueError:\r\n        print(f\"Invalid size in XYZ plot: {x}\")\r\n\r\n\r\ndef find_vae(name: str):\r\n    if (name := name.strip().lower()) in ('auto', 'automatic'):\r\n        return 'Automatic'\r\n    elif name == 'none':\r\n        return 'None'\r\n    return next((k for k in modules.sd_vae.vae_dict if k.lower() == name), 
print(f'No VAE found for {name}; using Automatic') or 'Automatic')\r\n\r\n\r\ndef apply_vae(p, x, xs):\r\n    p.override_settings['sd_vae'] = find_vae(x)\r\n\r\n\r\ndef apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):\r\n    p.styles.extend(x.split(','))\r\n\r\n\r\ndef apply_uni_pc_order(p, x, xs):\r\n    p.override_settings['uni_pc_order'] = min(x, p.steps - 1)\r\n\r\n\r\ndef apply_face_restore(p, opt, x):\r\n    opt = opt.lower()\r\n    if opt == 'codeformer':\r\n        is_active = True\r\n        p.face_restoration_model = 'CodeFormer'\r\n    elif opt == 'gfpgan':\r\n        is_active = True\r\n        p.face_restoration_model = 'GFPGAN'\r\n    else:\r\n        is_active = opt in ('true', 'yes', 'y', '1')\r\n\r\n    p.restore_faces = is_active\r\n\r\n\r\ndef apply_override(field, boolean: bool = False):\r\n    def fun(p, x, xs):\r\n        if boolean:\r\n            x = True if x.lower() == \"true\" else False\r\n        p.override_settings[field] = x\r\n\r\n    return fun\r\n\r\n\r\ndef boolean_choice(reverse: bool = False):\r\n    def choice():\r\n        return [\"False\", \"True\"] if reverse else [\"True\", \"False\"]\r\n\r\n    return choice\r\n\r\n\r\ndef format_value_add_label(p, opt, x):\r\n    if type(x) == float:\r\n        x = round(x, 8)\r\n\r\n    return f\"{opt.label}: {x}\"\r\n\r\n\r\ndef format_value(p, opt, x):\r\n    if type(x) == float:\r\n        x = round(x, 8)\r\n    return x\r\n\r\n\r\ndef format_value_join_list(p, opt, x):\r\n    return \", \".join(x)\r\n\r\n\r\ndef do_nothing(p, x, xs):\r\n    pass\r\n\r\n\r\ndef format_nothing(p, opt, x):\r\n    return \"\"\r\n\r\n\r\ndef format_remove_path(p, opt, x):\r\n    return os.path.basename(x)\r\n\r\n\r\ndef str_permutations(x):\r\n    \"\"\"dummy function for specifying it in AxisOption's type when you want to get a list of permutations\"\"\"\r\n    return x\r\n\r\n\r\ndef list_to_csv_string(data_list):\r\n    with StringIO() as o:\r\n        
csv.writer(o).writerow(data_list)\r\n        return o.getvalue().strip()\r\n\r\n\r\ndef csv_string_to_list_strip(data_str):\r\n    return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str), skipinitialspace=True))))\r\n\r\n\r\nclass AxisOption:\r\n    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None, prepare=None):\r\n        self.label = label\r\n        self.type = type\r\n        self.apply = apply\r\n        self.format_value = format_value\r\n        self.confirm = confirm\r\n        self.cost = cost\r\n        self.prepare = prepare\r\n        self.choices = choices\r\n\r\n\r\nclass AxisOptionImg2Img(AxisOption):\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n        self.is_img2img = True\r\n\r\n\r\nclass AxisOptionTxt2Img(AxisOption):\r\n    def __init__(self, *args, **kwargs):\r\n        super().__init__(*args, **kwargs)\r\n        self.is_img2img = False\r\n\r\n\r\naxis_options = [\r\n    AxisOption(\"Nothing\", str, do_nothing, format_value=format_nothing),\r\n    AxisOption(\"Seed\", int, apply_field(\"seed\")),\r\n    AxisOption(\"Var. seed\", int, apply_field(\"subseed\")),\r\n    AxisOption(\"Var. 
strength\", float, apply_field(\"subseed_strength\")),\r\n    AxisOption(\"Steps\", int, apply_field(\"steps\")),\r\n    AxisOptionTxt2Img(\"Hires steps\", int, apply_field(\"hr_second_pass_steps\")),\r\n    AxisOption(\"CFG Scale\", float, apply_field(\"cfg_scale\")),\r\n    AxisOptionImg2Img(\"Image CFG Scale\", float, apply_field(\"image_cfg_scale\")),\r\n    AxisOption(\"Prompt S/R\", str, apply_prompt, format_value=format_value),\r\n    AxisOption(\"Prompt order\", str_permutations, apply_order, format_value=format_value_join_list),\r\n    AxisOptionTxt2Img(\"Sampler\", str, apply_field(\"sampler_name\"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers if x.name not in opts.hide_samplers]),\r\n    AxisOptionTxt2Img(\"Hires sampler\", str, apply_field(\"hr_sampler_name\"), confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),\r\n    AxisOptionImg2Img(\"Sampler\", str, apply_field(\"sampler_name\"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img if x.name not in opts.hide_samplers]),\r\n    AxisOption(\"Checkpoint name\", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),\r\n    AxisOption(\"Negative Guidance minimum sigma\", float, apply_field(\"s_min_uncond\")),\r\n    AxisOption(\"Sigma Churn\", float, apply_field(\"s_churn\")),\r\n    AxisOption(\"Sigma min\", float, apply_field(\"s_tmin\")),\r\n    AxisOption(\"Sigma max\", float, apply_field(\"s_tmax\")),\r\n    AxisOption(\"Sigma noise\", float, apply_field(\"s_noise\")),\r\n    AxisOption(\"Schedule type\", str, apply_field(\"scheduler\"), choices=lambda: [x.label for x in sd_schedulers.schedulers]),\r\n    AxisOption(\"Schedule min sigma\", float, apply_override(\"sigma_min\")),\r\n    
AxisOption(\"Schedule max sigma\", float, apply_override(\"sigma_max\")),\r\n    AxisOption(\"Schedule rho\", float, apply_override(\"rho\")),\r\n    AxisOption(\"Beta schedule alpha\", float, apply_override(\"beta_dist_alpha\")),\r\n    AxisOption(\"Beta schedule beta\", float, apply_override(\"beta_dist_beta\")),\r\n    AxisOption(\"Eta\", float, apply_field(\"eta\")),\r\n    AxisOption(\"Clip skip\", int, apply_override('CLIP_stop_at_last_layers')),\r\n    AxisOption(\"Denoising\", float, apply_field(\"denoising_strength\")),\r\n    AxisOption(\"Initial noise multiplier\", float, apply_field(\"initial_noise_multiplier\")),\r\n    AxisOption(\"Extra noise\", float, apply_override(\"img2img_extra_noise\")),\r\n    AxisOptionTxt2Img(\"Hires upscaler\", str, apply_field(\"hr_upscaler\"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),\r\n    AxisOptionImg2Img(\"Cond. Image Mask Weight\", float, apply_field(\"inpainting_mask_weight\")),\r\n    AxisOption(\"VAE\", str, apply_vae, cost=0.7, choices=lambda: ['Automatic', 'None'] + list(sd_vae.vae_dict)),\r\n    AxisOption(\"Styles\", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),\r\n    AxisOption(\"UniPC Order\", int, apply_uni_pc_order, cost=0.5),\r\n    AxisOption(\"Face restore\", str, apply_face_restore, format_value=format_value),\r\n    AxisOption(\"Token merging ratio\", float, apply_override('token_merging_ratio')),\r\n    AxisOption(\"Token merging ratio high-res\", float, apply_override('token_merging_ratio_hr')),\r\n    AxisOption(\"Always discard next-to-last sigma\", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),\r\n    AxisOption(\"SGM noise multiplier\", str, apply_override('sgm_noise_multiplier', boolean=True), choices=boolean_choice(reverse=True)),\r\n    AxisOption(\"Refiner checkpoint\", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, 
confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)),\r\n    AxisOption(\"Refiner switch at\", float, apply_field('refiner_switch_at')),\r\n    AxisOption(\"RNG source\", str, apply_override(\"randn_source\"), choices=lambda: [\"GPU\", \"CPU\", \"NV\"]),\r\n    AxisOption(\"FP8 mode\", str, apply_override(\"fp8_storage\"), cost=0.9, choices=lambda: [\"Disable\", \"Enable for SDXL\", \"Enable\"]),\r\n    AxisOption(\"Size\", str, apply_size),\r\n]\r\n\r\n\r\ndef draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed, margin_size):\r\n    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]\r\n    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]\r\n    title_texts = [[images.GridAnnotation(z)] for z in z_labels]\r\n\r\n    list_size = (len(xs) * len(ys) * len(zs))\r\n\r\n    processed_result = None\r\n\r\n    state.job_count = list_size * p.n_iter\r\n\r\n    def process_cell(x, y, z, ix, iy, iz):\r\n        nonlocal processed_result\r\n\r\n        def index(ix, iy, iz):\r\n            return ix + iy * len(xs) + iz * len(xs) * len(ys)\r\n\r\n        state.job = f\"{index(ix, iy, iz) + 1} out of {list_size}\"\r\n\r\n        processed: Processed = cell(x, y, z, ix, iy, iz)\r\n\r\n        if processed_result is None:\r\n            # Use our first processed result object as a template container to hold our full results\r\n            processed_result = copy(processed)\r\n            processed_result.images = [None] * list_size\r\n            processed_result.all_prompts = [None] * list_size\r\n            processed_result.all_seeds = [None] * list_size\r\n            processed_result.infotexts = [None] * list_size\r\n            processed_result.index_of_first_image = 1\r\n\r\n        idx = index(ix, iy, iz)\r\n        if processed.images:\r\n            # Non-empty list 
indicates some degree of success.\r\n            processed_result.images[idx] = processed.images[0]\r\n            processed_result.all_prompts[idx] = processed.prompt\r\n            processed_result.all_seeds[idx] = processed.seed\r\n            processed_result.infotexts[idx] = processed.infotexts[0]\r\n        else:\r\n            cell_mode = \"P\"\r\n            cell_size = (processed_result.width, processed_result.height)\r\n            if processed_result.images[0] is not None:\r\n                cell_mode = processed_result.images[0].mode\r\n                # This corrects size in case of batches:\r\n                cell_size = processed_result.images[0].size\r\n            processed_result.images[idx] = Image.new(cell_mode, cell_size)\r\n\r\n    if first_axes_processed == 'x':\r\n        for ix, x in enumerate(xs):\r\n            if second_axes_processed == 'y':\r\n                for iy, y in enumerate(ys):\r\n                    for iz, z in enumerate(zs):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n            else:\r\n                for iz, z in enumerate(zs):\r\n                    for iy, y in enumerate(ys):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n    elif first_axes_processed == 'y':\r\n        for iy, y in enumerate(ys):\r\n            if second_axes_processed == 'x':\r\n                for ix, x in enumerate(xs):\r\n                    for iz, z in enumerate(zs):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n            else:\r\n                for iz, z in enumerate(zs):\r\n                    for ix, x in enumerate(xs):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n    elif first_axes_processed == 'z':\r\n        for iz, z in enumerate(zs):\r\n            if second_axes_processed == 'x':\r\n                for ix, x in enumerate(xs):\r\n                    for iy, y in enumerate(ys):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n            
else:\r\n                for iy, y in enumerate(ys):\r\n                    for ix, x in enumerate(xs):\r\n                        process_cell(x, y, z, ix, iy, iz)\r\n\r\n    if not processed_result:\r\n        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.\r\n        print(\"Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.\")\r\n        return Processed(p, [])\r\n    elif not any(processed_result.images):\r\n        print(\"Unexpected error: draw_xyz_grid failed to return even a single processed image\")\r\n        return Processed(p, [])\r\n\r\n    z_count = len(zs)\r\n\r\n    for i in range(z_count):\r\n        start_index = (i * len(xs) * len(ys)) + i\r\n        end_index = start_index + len(xs) * len(ys)\r\n        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))\r\n        if draw_legend:\r\n            grid_max_w, grid_max_h = map(max, zip(*(img.size for img in processed_result.images[start_index:end_index])))\r\n            grid = images.draw_grid_annotations(grid, grid_max_w, grid_max_h, hor_texts, ver_texts, margin_size)\r\n        processed_result.images.insert(i, grid)\r\n        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])\r\n        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])\r\n        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])\r\n\r\n    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)\r\n    z_sub_grid_max_w, z_sub_grid_max_h = map(max, zip(*(img.size for img in processed_result.images[:z_count])))\r\n    if draw_legend:\r\n        z_grid = images.draw_grid_annotations(z_grid, z_sub_grid_max_w, z_sub_grid_max_h, title_texts, [[images.GridAnnotation()]])\r\n    processed_result.images.insert(0, z_grid)\r\n    # TODO: Deeper aspects of the program rely on grid info being misaligned between 
metadata arrays, which is not ideal.\r\n    # processed_result.all_prompts.insert(0, processed_result.all_prompts[0])\r\n    # processed_result.all_seeds.insert(0, processed_result.all_seeds[0])\r\n    processed_result.infotexts.insert(0, processed_result.infotexts[0])\r\n\r\n    return processed_result\r\n\r\n\r\nclass SharedSettingsStackHelper(object):\r\n    def __enter__(self):\r\n        pass\r\n\r\n    def __exit__(self, exc_type, exc_value, tb):\r\n        modules.sd_models.reload_model_weights()\r\n        modules.sd_vae.reload_vae_weights()\r\n\r\n\r\nre_range = re.compile(r\"\\s*([+-]?\\s*\\d+)\\s*-\\s*([+-]?\\s*\\d+)(?:\\s*\\(([+-]\\d+)\\s*\\))?\\s*\")\r\nre_range_float = re.compile(r\"\\s*([+-]?\\s*\\d+(?:.\\d*)?)\\s*-\\s*([+-]?\\s*\\d+(?:.\\d*)?)(?:\\s*\\(([+-]\\d+(?:.\\d*)?)\\s*\\))?\\s*\")\r\n\r\nre_range_count = re.compile(r\"\\s*([+-]?\\s*\\d+)\\s*-\\s*([+-]?\\s*\\d+)(?:\\s*\\[(\\d+)\\s*])?\\s*\")\r\nre_range_count_float = re.compile(r\"\\s*([+-]?\\s*\\d+(?:.\\d*)?)\\s*-\\s*([+-]?\\s*\\d+(?:.\\d*)?)(?:\\s*\\[(\\d+(?:.\\d*)?)\\s*])?\\s*\")\r\n\r\n\r\nclass Script(scripts.Script):\r\n    def title(self):\r\n        return \"X/Y/Z plot\"\r\n\r\n    def ui(self, is_img2img):\r\n        self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img]\r\n\r\n        with gr.Row():\r\n            with gr.Column(scale=19):\r\n                with gr.Row():\r\n                    x_type = gr.Dropdown(label=\"X type\", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type=\"index\", elem_id=self.elem_id(\"x_type\"))\r\n                    x_values = gr.Textbox(label=\"X values\", lines=1, elem_id=self.elem_id(\"x_values\"))\r\n                    x_values_dropdown = gr.Dropdown(label=\"X values\", visible=False, multiselect=True, interactive=True)\r\n                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id=\"xyz_grid_fill_x_tool_button\", 
visible=False)\r\n\r\n                with gr.Row():\r\n                    y_type = gr.Dropdown(label=\"Y type\", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type=\"index\", elem_id=self.elem_id(\"y_type\"))\r\n                    y_values = gr.Textbox(label=\"Y values\", lines=1, elem_id=self.elem_id(\"y_values\"))\r\n                    y_values_dropdown = gr.Dropdown(label=\"Y values\", visible=False, multiselect=True, interactive=True)\r\n                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id=\"xyz_grid_fill_y_tool_button\", visible=False)\r\n\r\n                with gr.Row():\r\n                    z_type = gr.Dropdown(label=\"Z type\", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type=\"index\", elem_id=self.elem_id(\"z_type\"))\r\n                    z_values = gr.Textbox(label=\"Z values\", lines=1, elem_id=self.elem_id(\"z_values\"))\r\n                    z_values_dropdown = gr.Dropdown(label=\"Z values\", visible=False, multiselect=True, interactive=True)\r\n                    fill_z_button = ToolButton(value=fill_values_symbol, elem_id=\"xyz_grid_fill_z_tool_button\", visible=False)\r\n\r\n        with gr.Row(variant=\"compact\", elem_id=\"axis_options\"):\r\n            with gr.Column():\r\n                draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id(\"draw_legend\"))\r\n                no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id(\"no_fixed_seeds\"))\r\n                with gr.Row():\r\n                    vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id(\"vary_seeds_x\"), tooltip=\"Use different seeds for images along X axis.\")\r\n                    vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id(\"vary_seeds_y\"), tooltip=\"Use different 
seeds for images along Y axis.\")\r\n                    vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id(\"vary_seeds_z\"), tooltip=\"Use different seeds for images along Z axis.\")\r\n            with gr.Column():\r\n                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id(\"include_lone_images\"))\r\n                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id(\"include_sub_grids\"))\r\n                csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id(\"csv_mode\"))\r\n            with gr.Column():\r\n                margin_size = gr.Slider(label=\"Grid margins (px)\", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id(\"margin_size\"))\r\n\r\n        with gr.Row(variant=\"compact\", elem_id=\"swap_axes\"):\r\n            swap_xy_axes_button = gr.Button(value=\"Swap X/Y axes\", elem_id=\"xy_grid_swap_axes_button\")\r\n            swap_yz_axes_button = gr.Button(value=\"Swap Y/Z axes\", elem_id=\"yz_grid_swap_axes_button\")\r\n            swap_xz_axes_button = gr.Button(value=\"Swap X/Z axes\", elem_id=\"xz_grid_swap_axes_button\")\r\n\r\n        def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):\r\n            return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown\r\n\r\n        xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]\r\n        swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)\r\n        yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]\r\n        swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)\r\n        xz_swap_args = 
[x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]\r\n        swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)\r\n\r\n        def fill(axis_type, csv_mode):\r\n            axis = self.current_axis_options[axis_type]\r\n            if axis.choices:\r\n                if csv_mode:\r\n                    return list_to_csv_string(axis.choices()), gr.update()\r\n                else:\r\n                    return gr.update(), axis.choices()\r\n            else:\r\n                return gr.update(), gr.update()\r\n\r\n        fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown])\r\n        fill_y_button.click(fn=fill, inputs=[y_type, csv_mode], outputs=[y_values, y_values_dropdown])\r\n        fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown])\r\n\r\n        def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode):\r\n            axis_type = axis_type or 0  # if axle type is None set to 0\r\n\r\n            choices = self.current_axis_options[axis_type].choices\r\n            has_choices = choices is not None\r\n\r\n            if has_choices:\r\n                choices = choices()\r\n                if csv_mode:\r\n                    if axis_values_dropdown:\r\n                        axis_values = list_to_csv_string(list(filter(lambda x: x in choices, axis_values_dropdown)))\r\n                        axis_values_dropdown = []\r\n                else:\r\n                    if axis_values:\r\n                        axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values)))\r\n                        axis_values = \"\"\r\n\r\n            return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values),\r\n                    gr.update(choices=choices if has_choices else None, visible=has_choices and not csv_mode, 
value=axis_values_dropdown))\r\n\r\n        x_type.change(fn=select_axis, inputs=[x_type, x_values, x_values_dropdown, csv_mode], outputs=[fill_x_button, x_values, x_values_dropdown])\r\n        y_type.change(fn=select_axis, inputs=[y_type, y_values, y_values_dropdown, csv_mode], outputs=[fill_y_button, y_values, y_values_dropdown])\r\n        z_type.change(fn=select_axis, inputs=[z_type, z_values, z_values_dropdown, csv_mode], outputs=[fill_z_button, z_values, z_values_dropdown])\r\n\r\n        def change_choice_mode(csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown):\r\n            _fill_x_button, _x_values, _x_values_dropdown = select_axis(x_type, x_values, x_values_dropdown, csv_mode)\r\n            _fill_y_button, _y_values, _y_values_dropdown = select_axis(y_type, y_values, y_values_dropdown, csv_mode)\r\n            _fill_z_button, _z_values, _z_values_dropdown = select_axis(z_type, z_values, z_values_dropdown, csv_mode)\r\n            return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown\r\n\r\n        csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown])\r\n\r\n        def get_dropdown_update_from_params(axis, params):\r\n            val_key = f\"{axis} Values\"\r\n            vals = params.get(val_key, \"\")\r\n            valslist = csv_string_to_list_strip(vals)\r\n            return gr.update(value=valslist)\r\n\r\n        self.infotext_fields = (\r\n            (x_type, \"X Type\"),\r\n            (x_values, \"X Values\"),\r\n            (x_values_dropdown, lambda params: get_dropdown_update_from_params(\"X\", params)),\r\n            
(y_type, \"Y Type\"),\r\n            (y_values, \"Y Values\"),\r\n            (y_values_dropdown, lambda params: get_dropdown_update_from_params(\"Y\", params)),\r\n            (z_type, \"Z Type\"),\r\n            (z_values, \"Z Values\"),\r\n            (z_values_dropdown, lambda params: get_dropdown_update_from_params(\"Z\", params)),\r\n        )\r\n\r\n        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode]\r\n\r\n    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode):\r\n        x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0  # if axle type is None set to 0\r\n\r\n        if not no_fixed_seeds:\r\n            modules.processing.fix_seed(p)\r\n\r\n        if not opts.return_grid:\r\n            p.batch_size = 1\r\n\r\n        def process_axis(opt, vals, vals_dropdown):\r\n            if opt.label == 'Nothing':\r\n                return [0]\r\n\r\n            if opt.choices is not None and not csv_mode:\r\n                valslist = vals_dropdown\r\n            elif opt.prepare is not None:\r\n                valslist = opt.prepare(vals)\r\n            else:\r\n                valslist = csv_string_to_list_strip(vals)\r\n\r\n            if opt.type == int:\r\n                valslist_ext = []\r\n\r\n                for val in valslist:\r\n                    if val.strip() == '':\r\n                        continue\r\n                    m = re_range.fullmatch(val)\r\n                    mc = re_range_count.fullmatch(val)\r\n                    if m is not None:\r\n                        start = int(m.group(1))\r\n                        
end = int(m.group(2)) + 1\r\n                        step = int(m.group(3)) if m.group(3) is not None else 1\r\n\r\n                        valslist_ext += list(range(start, end, step))\r\n                    elif mc is not None:\r\n                        start = int(mc.group(1))\r\n                        end = int(mc.group(2))\r\n                        num = int(mc.group(3)) if mc.group(3) is not None else 1\r\n\r\n                        valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]\r\n                    else:\r\n                        valslist_ext.append(val)\r\n\r\n                valslist = valslist_ext\r\n            elif opt.type == float:\r\n                valslist_ext = []\r\n\r\n                for val in valslist:\r\n                    if val.strip() == '':\r\n                        continue\r\n                    m = re_range_float.fullmatch(val)\r\n                    mc = re_range_count_float.fullmatch(val)\r\n                    if m is not None:\r\n                        start = float(m.group(1))\r\n                        end = float(m.group(2))\r\n                        step = float(m.group(3)) if m.group(3) is not None else 1\r\n\r\n                        valslist_ext += np.arange(start, end + step, step).tolist()\r\n                    elif mc is not None:\r\n                        start = float(mc.group(1))\r\n                        end = float(mc.group(2))\r\n                        num = int(mc.group(3)) if mc.group(3) is not None else 1\r\n\r\n                        valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()\r\n                    else:\r\n                        valslist_ext.append(val)\r\n\r\n                valslist = valslist_ext\r\n            elif opt.type == str_permutations:\r\n                valslist = list(permutations(valslist))\r\n\r\n            valslist = [opt.type(x) for x in valslist]\r\n\r\n            # Confirm options are valid before 
starting\r\n            if opt.confirm:\r\n                opt.confirm(p, valslist)\r\n\r\n            return valslist\r\n\r\n        x_opt = self.current_axis_options[x_type]\r\n        if x_opt.choices is not None and not csv_mode:\r\n            x_values = list_to_csv_string(x_values_dropdown)\r\n        xs = process_axis(x_opt, x_values, x_values_dropdown)\r\n\r\n        y_opt = self.current_axis_options[y_type]\r\n        if y_opt.choices is not None and not csv_mode:\r\n            y_values = list_to_csv_string(y_values_dropdown)\r\n        ys = process_axis(y_opt, y_values, y_values_dropdown)\r\n\r\n        z_opt = self.current_axis_options[z_type]\r\n        if z_opt.choices is not None and not csv_mode:\r\n            z_values = list_to_csv_string(z_values_dropdown)\r\n        zs = process_axis(z_opt, z_values, z_values_dropdown)\r\n\r\n        # this could be moved to common code, but unlikely to be ever triggered anywhere else\r\n        Image.MAX_IMAGE_PIXELS = None  # disable check in Pillow and rely on check below to allow large custom image sizes\r\n        grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000)\r\n        assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)'\r\n\r\n        def fix_axis_seeds(axis_opt, axis_list):\r\n            if axis_opt.label in ['Seed', 'Var. 
seed']:\r\n                return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]\r\n            else:\r\n                return axis_list\r\n\r\n        if not no_fixed_seeds:\r\n            xs = fix_axis_seeds(x_opt, xs)\r\n            ys = fix_axis_seeds(y_opt, ys)\r\n            zs = fix_axis_seeds(z_opt, zs)\r\n\r\n        if x_opt.label == 'Steps':\r\n            total_steps = sum(xs) * len(ys) * len(zs)\r\n        elif y_opt.label == 'Steps':\r\n            total_steps = sum(ys) * len(xs) * len(zs)\r\n        elif z_opt.label == 'Steps':\r\n            total_steps = sum(zs) * len(xs) * len(ys)\r\n        else:\r\n            total_steps = p.steps * len(xs) * len(ys) * len(zs)\r\n\r\n        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:\r\n            if x_opt.label == \"Hires steps\":\r\n                total_steps += sum(xs) * len(ys) * len(zs)\r\n            elif y_opt.label == \"Hires steps\":\r\n                total_steps += sum(ys) * len(xs) * len(zs)\r\n            elif z_opt.label == \"Hires steps\":\r\n                total_steps += sum(zs) * len(xs) * len(ys)\r\n            elif p.hr_second_pass_steps:\r\n                total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs)\r\n            else:\r\n                total_steps *= 2\r\n\r\n        total_steps *= p.n_iter\r\n\r\n        image_cell_count = p.n_iter * p.batch_size\r\n        cell_console_text = f\"; {image_cell_count} images per cell\" if image_cell_count > 1 else \"\"\r\n        plural_s = 's' if len(zs) > 1 else ''\r\n        print(f\"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. 
(Total steps to process: {total_steps})\")\r\n        shared.total_tqdm.updateTotal(total_steps)\r\n\r\n        state.xyz_plot_x = AxisInfo(x_opt, xs)\r\n        state.xyz_plot_y = AxisInfo(y_opt, ys)\r\n        state.xyz_plot_z = AxisInfo(z_opt, zs)\r\n\r\n        # If one of the axes is very slow to change between (like SD model\r\n        # checkpoint), then make sure it is in the outer iteration of the nested\r\n        # `for` loop.\r\n        first_axes_processed = 'z'\r\n        second_axes_processed = 'y'\r\n        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:\r\n            first_axes_processed = 'x'\r\n            if y_opt.cost > z_opt.cost:\r\n                second_axes_processed = 'y'\r\n            else:\r\n                second_axes_processed = 'z'\r\n        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:\r\n            first_axes_processed = 'y'\r\n            if x_opt.cost > z_opt.cost:\r\n                second_axes_processed = 'x'\r\n            else:\r\n                second_axes_processed = 'z'\r\n        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:\r\n            first_axes_processed = 'z'\r\n            if x_opt.cost > y_opt.cost:\r\n                second_axes_processed = 'x'\r\n            else:\r\n                second_axes_processed = 'y'\r\n\r\n        grid_infotext = [None] * (1 + len(zs))\r\n\r\n        def cell(x, y, z, ix, iy, iz):\r\n            if shared.state.interrupted or state.stopping_generation:\r\n                return Processed(p, [], p.seed, \"\")\r\n\r\n            pc = copy(p)\r\n            pc.styles = pc.styles[:]\r\n            x_opt.apply(pc, x, xs)\r\n            y_opt.apply(pc, y, ys)\r\n            z_opt.apply(pc, z, zs)\r\n\r\n            xdim = len(xs) if vary_seeds_x else 1\r\n            ydim = len(ys) if vary_seeds_y else 1\r\n\r\n            if vary_seeds_x:\r\n                pc.seed += ix\r\n            if vary_seeds_y:\r\n                pc.seed += iy * xdim\r\n      
      if vary_seeds_z:\r\n                pc.seed += iz * xdim * ydim\r\n\r\n            try:\r\n                res = process_images(pc)\r\n            except Exception as e:\r\n                errors.display(e, \"generating image for xyz plot\")\r\n\r\n                res = Processed(p, [], p.seed, \"\")\r\n\r\n            # Sets subgrid infotexts\r\n            subgrid_index = 1 + iz\r\n            if grid_infotext[subgrid_index] is None and ix == 0 and iy == 0:\r\n                pc.extra_generation_params = copy(pc.extra_generation_params)\r\n                pc.extra_generation_params['Script'] = self.title()\r\n\r\n                if x_opt.label != 'Nothing':\r\n                    pc.extra_generation_params[\"X Type\"] = x_opt.label\r\n                    pc.extra_generation_params[\"X Values\"] = x_values\r\n                    if x_opt.label in [\"Seed\", \"Var. seed\"] and not no_fixed_seeds:\r\n                        pc.extra_generation_params[\"Fixed X Values\"] = \", \".join([str(x) for x in xs])\r\n\r\n                if y_opt.label != 'Nothing':\r\n                    pc.extra_generation_params[\"Y Type\"] = y_opt.label\r\n                    pc.extra_generation_params[\"Y Values\"] = y_values\r\n                    if y_opt.label in [\"Seed\", \"Var. 
seed\"] and not no_fixed_seeds:\r\n                        pc.extra_generation_params[\"Fixed Y Values\"] = \", \".join([str(y) for y in ys])\r\n\r\n                grid_infotext[subgrid_index] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)\r\n\r\n            # Sets main grid infotext\r\n            if grid_infotext[0] is None and ix == 0 and iy == 0 and iz == 0:\r\n                pc.extra_generation_params = copy(pc.extra_generation_params)\r\n\r\n                if z_opt.label != 'Nothing':\r\n                    pc.extra_generation_params[\"Z Type\"] = z_opt.label\r\n                    pc.extra_generation_params[\"Z Values\"] = z_values\r\n                    if z_opt.label in [\"Seed\", \"Var. seed\"] and not no_fixed_seeds:\r\n                        pc.extra_generation_params[\"Fixed Z Values\"] = \", \".join([str(z) for z in zs])\r\n\r\n                grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds)\r\n\r\n            return res\r\n\r\n        with SharedSettingsStackHelper():\r\n            processed = draw_xyz_grid(\r\n                p,\r\n                xs=xs,\r\n                ys=ys,\r\n                zs=zs,\r\n                x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],\r\n                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],\r\n                z_labels=[z_opt.format_value(p, z_opt, z) for z in zs],\r\n                cell=cell,\r\n                draw_legend=draw_legend,\r\n                include_lone_images=include_lone_images,\r\n                include_sub_grids=include_sub_grids,\r\n                first_axes_processed=first_axes_processed,\r\n                second_axes_processed=second_axes_processed,\r\n                margin_size=margin_size\r\n            )\r\n\r\n        if not processed.images:\r\n            # It broke, no further handling needed.\r\n            return processed\r\n\r\n        z_count = len(zs)\r\n\r\n       
 # Set the grid infotexts to the real ones with extra_generation_params (1 main grid + z_count sub-grids)\r\n        processed.infotexts[:1 + z_count] = grid_infotext[:1 + z_count]\r\n\r\n        if not include_lone_images:\r\n            # Don't need sub-images anymore, drop from list:\r\n            processed.images = processed.images[:z_count + 1]\r\n\r\n        if opts.grid_save:\r\n            # Auto-save main and sub-grids:\r\n            grid_count = z_count + 1 if z_count > 1 else 1\r\n            for g in range(grid_count):\r\n                # TODO: See previous comment about intentional data misalignment.\r\n                adj_g = g - 1 if g > 0 else g\r\n                images.save_image(processed.images[g], p.outpath_grids, \"xyz_grid\", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed)\r\n                if not include_sub_grids:  # if not include_sub_grids then skip saving after the first grid\r\n                    break\r\n\r\n        if not include_sub_grids:\r\n            # Done with sub-grids, drop all related information:\r\n            for _ in range(z_count):\r\n                del processed.images[1]\r\n                del processed.all_prompts[1]\r\n                del processed.all_seeds[1]\r\n                del processed.infotexts[1]\r\n\r\n        return processed\r\n"
  },
  {
    "path": "style.css",
    "content": "/* temporary fix to load default gradio font in frontend instead of backend */\r\n\r\n@import url('webui-assets/css/sourcesanspro.css');\r\n\r\n\r\n/* temporary fix to hide gradio crop tool until it's fixed https://github.com/gradio-app/gradio/issues/3810 */\r\n\r\ndiv.gradio-image button[aria-label=\"Edit\"] {\r\n    display: none;\r\n}\r\n\r\n\r\n/* general gradio fixes */\r\n\r\n:root, .dark{\r\n    --checkbox-label-gap: 0.25em 0.1em;\r\n    --section-header-text-size: 12pt;\r\n    --block-background-fill: transparent;\r\n\r\n}\r\n\r\n.block.padded:not(.gradio-accordion) {\r\n    padding: 0 !important;\r\n}\r\n\r\ndiv.gradio-container{\r\n    max-width: unset !important;\r\n}\r\n\r\n.hidden{\r\n    display: none !important;\r\n}\r\n\r\n.compact{\r\n    background: transparent !important;\r\n    padding: 0 !important;\r\n}\r\n\r\ndiv.form{\r\n    border-width: 0;\r\n    box-shadow: none;\r\n    background: transparent;\r\n    overflow: visible;\r\n    gap: 0.5em;\r\n}\r\n\r\n.block.gradio-dropdown,\r\n.block.gradio-slider,\r\n.block.gradio-checkbox,\r\n.block.gradio-textbox,\r\n.block.gradio-radio,\r\n.block.gradio-checkboxgroup,\r\n.block.gradio-number,\r\n.block.gradio-colorpicker {\r\n    border-width: 0 !important;\r\n    box-shadow: none !important;\r\n}\r\n\r\ndiv.gradio-group, div.styler{\r\n    border-width: 0 !important;\r\n    background: none;\r\n}\r\n.gap.compact{\r\n    padding: 0;\r\n    gap: 0.2em 0;\r\n}\r\n\r\ndiv.compact{\r\n    gap: 1em;\r\n}\r\n\r\n.gradio-dropdown label span:not(.has-info),\r\n.gradio-textbox label span:not(.has-info),\r\n.gradio-number label span:not(.has-info)\r\n{\r\n    margin-bottom: 0;\r\n}\r\n\r\n.gradio-dropdown ul.options{\r\n    z-index: 3000;\r\n    min-width: fit-content;\r\n    max-width: inherit;\r\n    white-space: nowrap;\r\n}\r\n\r\n@media (pointer:fine) {\r\n    .gradio-dropdown ul.options li.item {\r\n        padding: 0.05em 0;\r\n    }\r\n}\r\n\r\n.gradio-dropdown ul.options 
li.item.selected {\r\n    background-color: var(--neutral-100);\r\n}\r\n\r\n.dark .gradio-dropdown ul.options li.item.selected {\r\n    background-color: var(--neutral-900);\r\n}\r\n\r\n.gradio-dropdown div.wrap.wrap.wrap.wrap{\r\n    box-shadow: 0 1px 2px 0 rgba(0, 0, 0, 0.05);\r\n}\r\n\r\n.gradio-dropdown:not(.multiselect) .wrap-inner.wrap-inner.wrap-inner{\r\n    flex-wrap: unset;\r\n}\r\n\r\n.gradio-dropdown .single-select{\r\n    white-space: nowrap;\r\n    overflow: hidden;\r\n}\r\n\r\n.gradio-dropdown .token-remove.remove-all.remove-all{\r\n    display: none;\r\n}\r\n\r\n.gradio-dropdown.multiselect .token-remove.remove-all.remove-all{\r\n    display: flex;\r\n}\r\n\r\n.gradio-slider input[type=\"number\"]{\r\n    width: 6em;\r\n}\r\n\r\n.block.gradio-checkbox {\r\n    margin: 0.75em 1.5em 0 0;\r\n}\r\n\r\n.gradio-html div.wrap{\r\n    height: 100%;\r\n}\r\ndiv.gradio-html.min{\r\n    min-height: 0;\r\n}\r\n\r\n.block.gradio-gallery{\r\n    background: var(--input-background-fill);\r\n}\r\n\r\n.gradio-container .prose a, .gradio-container .prose a:visited{\r\n    color: unset;\r\n    text-decoration: none;\r\n}\r\n\r\na{\r\n    font-weight: bold;\r\n    cursor: pointer;\r\n}\r\n\r\n/* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reason. */\r\ndiv.gradio-container, .block.gradio-textbox, div.gradio-group, div.gradio-dropdown{\r\n    overflow: visible !important;\r\n}\r\n\r\n/* align-items isn't enough and elements may overflow in Safari. 
*/\r\n.unequal-height {\r\n    align-content: flex-start;\r\n}\r\n\r\n\r\n/* general styled components */\r\n\r\n.gradio-button.tool{\r\n    max-width: 2.2em;\r\n    min-width: 2.2em !important;\r\n    height: 2.4em;\r\n    align-self: end;\r\n    line-height: 1em;\r\n    border-radius: 0.5em;\r\n}\r\n\r\n.gradio-button.secondary-down{\r\n    background: var(--button-secondary-background-fill);\r\n    color: var(--button-secondary-text-color);\r\n}\r\n.gradio-button.secondary-down, .gradio-button.secondary-down:hover{\r\n    box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;\r\n}\r\n.gradio-button.secondary-down:hover{\r\n    background: var(--button-secondary-background-fill-hover);\r\n    color: var(--button-secondary-text-color-hover);\r\n}\r\n\r\nbutton.custom-button{\r\n    border-radius: var(--button-large-radius);\r\n    padding: var(--button-large-padding);\r\n    font-weight: var(--button-large-text-weight);\r\n    border: var(--button-border-width) solid var(--button-secondary-border-color);\r\n    background: var(--button-secondary-background-fill);\r\n    color: var(--button-secondary-text-color);\r\n    font-size: var(--button-large-text-size);\r\n    display: inline-flex;\r\n    justify-content: center;\r\n    align-items: center;\r\n    transition: var(--button-transition);\r\n    box-shadow: var(--button-shadow);\r\n    text-align: center;\r\n}\r\n\r\ndiv.block.gradio-accordion {\r\n    border: 1px solid var(--block-border-color) !important;\r\n    border-radius: 8px !important;\r\n    margin: 2px 0;\r\n    padding: 8px 8px;\r\n}\r\n\r\ninput[type=\"checkbox\"].input-accordion-checkbox{\r\n    vertical-align: sub;\r\n    margin-right: 0.5em;\r\n}\r\n\r\n\r\n/* txt2img/img2img specific */\r\n\r\n.block.token-counter{\r\n    position: absolute;\r\n    display: inline-block;\r\n    right: 1em;\r\n    min-width: 0 !important;\r\n    width: auto;\r\n    z-index: 100;\r\n    top: 
-0.75em;\r\n}\r\n\r\n.block.token-counter-visible{\r\n    display: block !important;\r\n}\r\n\r\n.block.token-counter span{\r\n    background: var(--input-background-fill) !important;\r\n    box-shadow: 0 0 0.0 0.3em rgba(192,192,192,0.15), inset 0 0 0.6em rgba(192,192,192,0.075);\r\n    border: 2px solid rgba(192,192,192,0.4) !important;\r\n    border-radius: 0.4em;\r\n}\r\n\r\n.block.token-counter.error span{\r\n    box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075);\r\n    border: 2px solid rgba(255,0,0,0.4) !important;\r\n}\r\n\r\n.block.token-counter div{\r\n    display: inline;\r\n}\r\n\r\n.block.token-counter span{\r\n    padding: 0.1em 0.75em;\r\n}\r\n\r\n[id$=_subseed_show]{\r\n    min-width: auto !important;\r\n    flex-grow: 0 !important;\r\n    display: flex;\r\n}\r\n\r\n[id$=_subseed_show] label{\r\n    margin-bottom: 0.65em;\r\n    align-self: end;\r\n}\r\n\r\n[id$=_seed_extras] > div{\r\n    gap: 0.5em;\r\n}\r\n\r\n.html-log .comments{\r\n    padding-top: 0.5em;\r\n}\r\n\r\n.html-log .comments:empty{\r\n    padding-top: 0;\r\n}\r\n\r\n.html-log .performance {\r\n    font-size: 0.85em;\r\n    color: #444;\r\n    display: flex;\r\n}\r\n\r\n.html-log .performance p{\r\n    display: inline-block;\r\n}\r\n\r\n.html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr {\r\n    margin-bottom: 0;\r\n    color: var(--block-title-text-color);\r\n}\r\n\r\n.html-log .performance p.time {\r\n}\r\n\r\n.html-log .performance p.vram {\r\n    margin-left: auto;\r\n}\r\n\r\n.html-log .performance p.profile {\r\n    margin-left: 0.5em;\r\n}\r\n\r\n.html-log .performance .measurement{\r\n    color: var(--body-text-color);\r\n    font-weight: bold;\r\n}\r\n\r\n#txt2img_generate, #img2img_generate {\r\n    min-height: 4.5em;\r\n}\r\n\r\n#txt2img_generate, #img2img_generate {\r\n    min-height: 4.5em;\r\n}\r\n.generate-box-compact #txt2img_generate, .generate-box-compact 
#img2img_generate {\r\n    min-height: 3em;\r\n}\r\n\r\n@media screen and (min-width: 2500px) {\r\n    #txt2img_gallery, #img2img_gallery {\r\n        min-height: 768px;\r\n    }\r\n}\r\n\r\n.gradio-gallery .thumbnails img {\r\n    object-fit: scale-down !important;\r\n}\r\n#txt2img_actions_column, #img2img_actions_column {\r\n    gap: 0.5em;\r\n}\r\n#txt2img_tools, #img2img_tools{\r\n    gap: 0.4em;\r\n}\r\n\r\n.interrogate-col{\r\n    min-width: 0 !important;\r\n    max-width: fit-content;\r\n    gap: 0.5em;\r\n}\r\n.interrogate-col > button{\r\n    flex: 1;\r\n}\r\n\r\n.generate-box{\r\n    position: relative;\r\n}\r\n.gradio-button.generate-box-skip, .gradio-button.generate-box-interrupt, .gradio-button.generate-box-interrupting{\r\n    position: absolute;\r\n    width: 50%;\r\n    height: 100%;\r\n    display: none;\r\n    background: #b4c0cc;\r\n}\r\n.gradio-button.generate-box-skip:hover, .gradio-button.generate-box-interrupt:hover, .gradio-button.generate-box-interrupting:hover{\r\n    background: #c2cfdb;\r\n}\r\n.gradio-button.generate-box-interrupt, .gradio-button.generate-box-interrupting{\r\n    left: 0;\r\n    border-radius: 0.5rem 0 0 0.5rem;\r\n}\r\n.gradio-button.generate-box-skip{\r\n    right: 0;\r\n    border-radius: 0 0.5rem 0.5rem 0;\r\n}\r\n\r\n#img2img_scale_resolution_preview.block{\r\n    display: flex;\r\n    align-items: end;\r\n}\r\n\r\n#txtimg_hr_finalres .resolution, #img2img_scale_resolution_preview .resolution{\r\n    font-weight: bold;\r\n}\r\n\r\n#txtimg_hr_finalres div.pending, #img2img_scale_resolution_preview div.pending {\r\n    opacity: 1;\r\n    transition: opacity 0s;\r\n}\r\n\r\n.inactive{\r\n    opacity: 0.5;\r\n}\r\n\r\n[id$=_column_batch]{\r\n    min-width: min(13.5em, 100%) !important;\r\n}\r\n\r\ndiv.dimensions-tools{\r\n    min-width: 1.6em !important;\r\n    max-width: fit-content;\r\n    flex-direction: column;\r\n    place-content: center;\r\n}\r\n\r\ndiv#extras_scale_to_tab div.form{\r\n    flex-direction: 
row;\r\n}\r\n\r\n#img2img_sketch, #img2maskimg, #inpaint_sketch {\r\n    overflow: overlay !important;\r\n    resize: auto;\r\n    background: var(--panel-background-fill);\r\n    z-index: 5;\r\n}\r\n\r\n.image-buttons > .form{\r\n    justify-content: center;\r\n}\r\n\r\n.infotext {\r\n    overflow-wrap: break-word;\r\n}\r\n\r\n#img2img_column_batch{\r\n    align-self: end;\r\n    margin-bottom: 0.9em;\r\n}\r\n\r\n#img2img_unused_scale_by_slider{\r\n    visibility: hidden;\r\n    width: 0.5em;\r\n    max-width: 0.5em;\r\n    min-width: 0.5em;\r\n}\r\n\r\ndiv.toprow-compact-stylerow{\r\n    margin: 0.5em 0;\r\n}\r\n\r\ndiv.toprow-compact-tools{\r\n    min-width: fit-content !important;\r\n    max-width: fit-content;\r\n}\r\n\r\n/* settings */\r\n#quicksettings {\r\n    align-items: end;\r\n}\r\n\r\n#quicksettings > div, #quicksettings > fieldset{\r\n    max-width: 36em;\r\n    width: fit-content;\r\n    flex: 0 1 fit-content;\r\n    padding: 0;\r\n    border: none;\r\n    box-shadow: none;\r\n    background: none;\r\n}\r\n#quicksettings > div.gradio-dropdown{\r\n    min-width: 24em !important;\r\n}\r\n\r\n#settings{\r\n    display: block;\r\n}\r\n\r\n#settings > div{\r\n    border: none;\r\n    margin-left: 10em;\r\n    padding: 0 var(--spacing-xl);\r\n}\r\n\r\n#settings > div.tab-nav{\r\n    float: left;\r\n    display: block;\r\n    margin-left: 0;\r\n    width: 10em;\r\n}\r\n\r\n#settings > div.tab-nav button{\r\n    display: block;\r\n    border: none;\r\n    text-align: left;\r\n    white-space: initial;\r\n    padding: 4px;\r\n}\r\n\r\n#settings > div.tab-nav .settings-category{\r\n    display: block;\r\n    margin: 1em 0 0.25em 0;\r\n    font-weight: bold;\r\n    text-decoration: underline;\r\n    cursor: default;\r\n    user-select: none;\r\n}\r\n\r\n#settings_result{\r\n    height: 1.4em;\r\n    margin: 0 1.2em;\r\n}\r\n\r\ntable.popup-table{\r\n    background: var(--body-background-fill);\r\n    color: var(--body-text-color);\r\n    border-collapse: 
collapse;\r\n    margin: 1em;\r\n    border: 4px solid var(--body-background-fill);\r\n}\r\n\r\ntable.popup-table td{\r\n    padding: 0.4em;\r\n    border: 1px solid rgba(128, 128, 128, 0.5);\r\n    max-width: 36em;\r\n}\r\n\r\ntable.popup-table .muted{\r\n    color: #aaa;\r\n}\r\n\r\ntable.popup-table .link{\r\n    text-decoration: underline;\r\n    cursor: pointer;\r\n    font-weight: bold;\r\n}\r\n\r\n.ui-defaults-none{\r\n    color: #aaa !important;\r\n}\r\n\r\n#settings span{\r\n    color: var(--body-text-color);\r\n}\r\n\r\n#settings .gradio-textbox, #settings .gradio-slider, #settings .gradio-number, #settings .gradio-dropdown, #settings .gradio-checkboxgroup, #settings .gradio-radio{\r\n    margin-top: 0.75em;\r\n}\r\n\r\n#settings span .settings-comment {\r\n    display: inline\r\n}\r\n\r\n.settings-comment a{\r\n    text-decoration: underline;\r\n}\r\n\r\n.settings-comment .info{\r\n    opacity: 0.75;\r\n}\r\n\r\n.settings-comment .info ol{\r\n    margin: 0.4em 0 0.8em 1em;\r\n}\r\n\r\n#sysinfo_download a.sysinfo_big_link{\r\n    font-size: 24pt;\r\n}\r\n\r\n#sysinfo_download a{\r\n    text-decoration: underline;\r\n}\r\n\r\n#sysinfo_validity{\r\n    font-size: 18pt;\r\n}\r\n\r\n#settings .settings-info{\r\n    max-width: 48em;\r\n    border: 1px dotted #777;\r\n    margin: 0;\r\n    padding: 1em;\r\n}\r\n\r\n\r\n/* live preview */\r\n.progressDiv{\r\n    position: absolute;\r\n    height: 20px;\r\n    background: #b4c0cc;\r\n    border-radius: 3px !important;\r\n    top: -14px;\r\n    left: 0px;\r\n    width: 100%;\r\n}\r\n\r\n.progress-container{\r\n    position: relative;\r\n}\r\n\r\n[id$=_results].mobile{\r\n    margin-top: 28px;\r\n}\r\n\r\n.dark .progressDiv{\r\n    background: #424c5b;\r\n}\r\n\r\n.progressDiv .progress{\r\n    width: 0%;\r\n    height: 20px;\r\n    background: #0060df;\r\n    color: white;\r\n    font-weight: bold;\r\n    line-height: 20px;\r\n    padding: 0 8px 0 0;\r\n    text-align: right;\r\n    border-radius: 3px;\r\n    
overflow: visible;\r\n    white-space: nowrap;\r\n    padding: 0 0.5em;\r\n}\r\n\r\n.livePreview{\r\n    position: absolute;\r\n    z-index: 300;\r\n    background: var(--background-fill-primary);\r\n    width: 100%;\r\n    height: 100%;\r\n}\r\n\r\n.livePreview img{\r\n    position: absolute;\r\n    object-fit: contain;\r\n    width: 100%;\r\n    height: calc(100% - 60px);  /* to match gradio's height */\r\n}\r\n\r\n/* fullscreen popup (ie in Lora's (i) button) */\r\n\r\n.popup-metadata{\r\n    color: black;\r\n    background: white;\r\n    display: inline-block;\r\n    padding: 1em;\r\n    white-space: pre-wrap;\r\n}\r\n\r\n.global-popup{\r\n    display: flex;\r\n    position: fixed;\r\n    z-index: 1001;\r\n    left: 0;\r\n    top: 0;\r\n    width: 100%;\r\n    height: 100%;\r\n    overflow: auto;\r\n}\r\n\r\n.global-popup *{\r\n    box-sizing: border-box;\r\n}\r\n\r\n.global-popup-close:before {\r\n    content: \"×\";\r\n    position: fixed;\r\n    right: 0.25em;\r\n    top: 0;\r\n    cursor: pointer;\r\n    color: white;\r\n    font-size: 32pt;\r\n}\r\n\r\n.global-popup-close{\r\n    position: fixed;\r\n    left: 0;\r\n    top: 0;\r\n    width: 100%;\r\n    height: 100%;\r\n    background-color: rgba(20, 20, 20, 0.95);\r\n}\r\n\r\n.global-popup-inner{\r\n    display: inline-block;\r\n    margin: auto;\r\n    padding: 2em;\r\n    z-index: 1001;\r\n    max-height: 90%;\r\n    max-width: 90%;\r\n}\r\n\r\n/* fullpage image viewer */\r\n\r\n#lightboxModal{\r\n    display: none;\r\n    position: fixed;\r\n    z-index: 1001;\r\n    left: 0;\r\n    top: 0;\r\n    width: 100%;\r\n    height: 100%;\r\n    overflow: auto;\r\n    background-color: rgba(20, 20, 20, 0.95);\r\n    user-select: none;\r\n    -webkit-user-select: none;\r\n    flex-direction: column;\r\n}\r\n\r\n.modalControls {\r\n    display: flex;\r\n    position: absolute;\r\n    right: 0px;\r\n    left: 0px;\r\n    gap: 1em;\r\n    padding: 1em;\r\n    background-color:rgba(0,0,0,0);\r\n    z-index: 1;\r\n  
  transition: 0.2s ease background-color;\r\n}\r\n.modalControls:hover {\r\n    background-color:rgba(0,0,0, var(--sd-webui-modal-lightbox-toolbar-opacity));\r\n}\r\n.modalClose {\r\n    margin-left: auto;\r\n}\r\n.modalControls span{\r\n    color: white;\r\n    text-shadow: 0px 0px 0.25em black;\r\n    font-size: 35px;\r\n    font-weight: bold;\r\n    cursor: pointer;\r\n    width: 1em;\r\n}\r\n\r\n.modalControls span:hover, .modalControls span:focus{\r\n    color: #999;\r\n    text-decoration: none;\r\n}\r\n\r\n#lightboxModal > img {\r\n    display: block;\r\n    margin: auto;\r\n    width: auto;\r\n}\r\n\r\n#lightboxModal > img.modalImageFullscreen{\r\n    object-fit: contain;\r\n    height: 100%;\r\n    width: 100%;\r\n    min-height: 0;\r\n}\r\n\r\n.modalPrev,\r\n.modalNext {\r\n  cursor: pointer;\r\n  position: absolute;\r\n  top: 50%;\r\n  width: auto;\r\n  padding: 16px;\r\n  margin-top: -50px;\r\n  color: white;\r\n  font-weight: bold;\r\n  font-size: 20px;\r\n  transition: 0.6s ease;\r\n  border-radius: 0 3px 3px 0;\r\n  user-select: none;\r\n  -webkit-user-select: none;\r\n}\r\n\r\n.modalNext {\r\n  right: 0;\r\n  border-radius: 3px 0 0 3px;\r\n}\r\n\r\n.modalPrev:hover,\r\n.modalNext:hover {\r\n  background-color: rgba(0, 0, 0, 0.8);\r\n}\r\n\r\n#imageARPreview {\r\n    position: absolute;\r\n    top: 0px;\r\n    left: 0px;\r\n    border: 2px solid red;\r\n    background: rgba(255, 0, 0, 0.3);\r\n    z-index: 900;\r\n    pointer-events: none;\r\n    display: none;\r\n}\r\n\r\n@media (pointer: fine) {\r\n    .modalPrev:hover,\r\n    .modalNext:hover,\r\n    .modalControls:hover ~ .modalPrev,\r\n    .modalControls:hover ~ .modalNext,\r\n    .modalControls:hover .cursor {\r\n        opacity: 1;\r\n    }\r\n\r\n    .modalPrev,\r\n    .modalNext,\r\n    .modalControls .cursor {\r\n        opacity: var(--sd-webui-modal-lightbox-icon-opacity);\r\n    }\r\n}\r\n\r\n/* context menu (ie for the generate button) */\r\n\r\n#context-menu{\r\n    z-index:9999;\r\n    
position:absolute;\r\n    display:block;\r\n    padding:0px 0;\r\n    border:2px solid var(--primary-800);\r\n    border-radius:8px;\r\n    box-shadow:1px 1px 2px var(--primary-500);\r\n    width: 200px;\r\n}\r\n\r\n.context-menu-items{\r\n    list-style: none;\r\n    margin: 0;\r\n    padding: 0;\r\n}\r\n\r\n.context-menu-items a{\r\n    display:block;\r\n    padding:5px;\r\n    cursor:pointer;\r\n}\r\n\r\n.context-menu-items a:hover{\r\n    background: var(--primary-700);\r\n}\r\n\r\n\r\n/* extensions */\r\n\r\n#tab_extensions table{\r\n    border-collapse: collapse;\r\n    overflow-x: auto;\r\n    display: block;\r\n}\r\n\r\n#tab_extensions table td, #tab_extensions table th{\r\n    border: 1px solid #ccc;\r\n    padding: 0.25em 0.5em;\r\n}\r\n\r\n#tab_extensions table input[type=\"checkbox\"]{\r\n    margin-right: 0.5em;\r\n    appearance: checkbox;\r\n}\r\n\r\n#tab_extensions button{\r\n    max-width: 16em;\r\n}\r\n\r\n#tab_extensions input[disabled=\"disabled\"]{\r\n    opacity: 0.5;\r\n}\r\n\r\n.extension-tag{\r\n    font-weight: bold;\r\n    font-size: 95%;\r\n}\r\n\r\n#available_extensions .info{\r\n    margin: 0;\r\n}\r\n\r\n#available_extensions .info{\r\n    margin: 0.5em 0;\r\n    display: flex;\r\n    margin-top: auto;\r\n    opacity: 0.80;\r\n    font-size: 90%;\r\n}\r\n\r\n#available_extensions .date_added{\r\n    margin-right: auto;\r\n    display: inline-block;\r\n}\r\n\r\n#available_extensions .star_count{\r\n    margin-left: auto;\r\n    display: inline-block;\r\n}\r\n\r\n.compact-checkbox-group  div label {\r\n    padding: 0.1em 0.3em !important;\r\n}\r\n\r\n/* extensions tab table row hover highlight */\r\n\r\n#extensions tr:hover td,\r\n#config_state_extensions tr:hover td,\r\n#available_extensions tr:hover td {\r\n    background: rgba(0, 0, 0, 0.15);\r\n}\r\n\r\n.dark #extensions tr:hover td ,\r\n.dark #config_state_extensions tr:hover td ,\r\n.dark #available_extensions tr:hover td {\r\n    background: rgba(255, 255, 255, 
0.15);\r\n}\r\n\r\n/* replace original footer with ours */\r\n\r\nfooter {\r\n    display: none !important;\r\n}\r\n\r\n#footer{\r\n    text-align: center;\r\n}\r\n\r\n#footer div{\r\n    display: inline-block;\r\n}\r\n\r\n#footer .versions{\r\n    font-size: 85%;\r\n    opacity: 0.85;\r\n}\r\n\r\n/* extra networks UI */\r\n\r\n.extra-page > div.gap{\r\n    gap: 0;\r\n}\r\n\r\n.extra-page-prompts{\r\n    margin-bottom: 0;\r\n}\r\n\r\n.extra-page-prompts.extra-page-prompts-active{\r\n    margin-bottom: 1em;\r\n}\r\n\r\n.extra-networks > div.tab-nav{\r\n    min-height: 2.7rem;\r\n}\r\n\r\n.extra-networks-controls-div{\r\n    align-self: center;\r\n    margin-left: auto;\r\n}\r\n\r\n.extra-networks > div > [id *= '_extra_']{\r\n    margin: 0.3em;\r\n}\r\n\r\n.extra-networks .tab-nav .search,\r\n.extra-networks .tab-nav .sort\r\n{\r\n    margin: 0.3em;\r\n    align-self: center;\r\n    width: auto;\r\n}\r\n\r\n.extra-networks .tab-nav .search {\r\n    width: 16em;\r\n    max-width: 16em;\r\n}\r\n\r\n.extra-networks .tab-nav .sort {\r\n    width: 12em;\r\n    max-width: 12em;\r\n}\r\n\r\n#txt2img_extra_view, #img2img_extra_view {\r\n    width: auto;\r\n}\r\n\r\n.extra-network-pane .nocards{\r\n    margin: 1.25em 0.5em 0.5em 0.5em;\r\n}\r\n\r\n.extra-network-pane .nocards h1{\r\n    font-size: 1.5em;\r\n    margin-bottom: 1em;\r\n}\r\n\r\n.extra-network-pane .nocards li{\r\n    margin-left: 0.5em;\r\n}\r\n\r\n.extra-network-pane .card .button-row{\r\n    display: inline-flex;\r\n    visibility: hidden;\r\n    color: white;\r\n}\r\n\r\n.extra-network-pane .card .button-row {\r\n    position: absolute;\r\n    right: 0;\r\n    z-index: 1;\r\n}\r\n\r\n.extra-network-pane .card:hover .button-row{\r\n    visibility: visible;\r\n}\r\n\r\n.extra-network-pane .card-button{\r\n    color: white;\r\n}\r\n\r\n.extra-network-pane .copy-path-button::before {\r\n    content: \"⎘\";\r\n}\r\n\r\n.extra-network-pane .metadata-button::before{\r\n    content: 
\"🛈\";\r\n}\r\n\r\n.extra-network-pane .edit-button::before{\r\n    content: \"🛠\";\r\n}\r\n\r\n.extra-network-pane .card-button {\r\n    width: 1.5em;\r\n    text-shadow: 2px 2px 3px black;\r\n    color: white;\r\n    padding: 0.25em 0.1em;\r\n}\r\n\r\n.extra-network-pane .card-button:hover{\r\n    color: red;\r\n}\r\n\r\n.extra-network-pane .card .card-button {\r\n    font-size: 2rem;\r\n}\r\n\r\n.extra-network-pane .card-minimal .card-button {\r\n    font-size: 1rem;\r\n}\r\n\r\n.standalone-card-preview.card .preview{\r\n    position: absolute;\r\n    object-fit: cover;\r\n    width: 100%;\r\n    height:100%;\r\n}\r\n\r\n.extra-network-pane .card, .standalone-card-preview.card{\r\n    display: inline-block;\r\n    margin: 0.5rem;\r\n    width: 16rem;\r\n    height: 24rem;\r\n    box-shadow: 0 0 5px rgba(128, 128, 128, 0.5);\r\n    border-radius: 0.2rem;\r\n    position: relative;\r\n\r\n    background-size: auto 100%;\r\n    background-position: center;\r\n    overflow: hidden;\r\n    cursor: pointer;\r\n\r\n    background-image: url('./file=html/card-no-preview.png')\r\n}\r\n\r\n.extra-network-pane .card:hover{\r\n    box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35);\r\n}\r\n\r\n.extra-network-pane .card .actions .additional{\r\n    display: none;\r\n}\r\n\r\n.extra-network-pane .card .actions{\r\n    position: absolute;\r\n    bottom: 0;\r\n    left: 0;\r\n    right: 0;\r\n    padding: 0.5em;\r\n    background: rgba(0,0,0,0.5);\r\n    box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5);\r\n    text-shadow: 0 0 0.2em black;\r\n}\r\n\r\n.extra-network-pane .card .actions *{\r\n    color: white;\r\n}\r\n\r\n.extra-network-pane .card .actions .name{\r\n    font-size: 1.7em;\r\n    font-weight: bold;\r\n    line-break: anywhere;\r\n}\r\n\r\n.extra-network-pane .card .actions .description {\r\n    display: block;\r\n    max-height: 3em;\r\n    white-space: pre-wrap;\r\n    line-height: 1.1;\r\n}\r\n\r\n.extra-network-pane .card .actions .description:hover {\r\n    
max-height: none;\r\n}\r\n\r\n.extra-network-pane .card .actions:hover .additional{\r\n    display: block;\r\n}\r\n\r\n.extra-network-pane .card ul{\r\n    margin: 0.25em 0 0.75em 0.25em;\r\n    cursor: unset;\r\n}\r\n\r\n.extra-network-pane .card ul a{\r\n    cursor: pointer;\r\n}\r\n\r\n.extra-network-pane .card ul a:hover{\r\n    color: red;\r\n}\r\n\r\n.extra-network-pane .card .preview{\r\n    position: absolute;\r\n    object-fit: cover;\r\n    width: 100%;\r\n    height:100%;\r\n}\r\n\r\ndiv.block.gradio-box.edit-user-metadata {\r\n    width: 56em;\r\n    background: var(--body-background-fill);\r\n    padding: 2em !important;\r\n}\r\n\r\n.edit-user-metadata .extra-network-name{\r\n    font-size: 18pt;\r\n    color: var(--body-text-color);\r\n}\r\n\r\n.edit-user-metadata .file-metadata{\r\n    color: var(--body-text-color);\r\n}\r\n\r\n.edit-user-metadata .file-metadata th{\r\n    text-align: left;\r\n}\r\n\r\n.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{\r\n    padding: 0.3em 1em;\r\n    overflow-wrap: anywhere;\r\n    word-break: break-word;\r\n}\r\n\r\n.edit-user-metadata .wrap.translucent{\r\n    background: var(--body-background-fill);\r\n}\r\n.edit-user-metadata .gradio-highlightedtext span{\r\n    word-break: break-word;\r\n}\r\n\r\n.edit-user-metadata-buttons{\r\n    margin-top: 1.5em;\r\n}\r\n\r\ndiv.block.gradio-box.popup-dialog, .popup-dialog {\r\n    width: 56em;\r\n    background: var(--body-background-fill);\r\n    padding: 2em !important;\r\n}\r\n\r\ndiv.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{\r\n    margin-top: 1em;\r\n}\r\n\r\ndiv.block.input-accordion{\r\n\r\n}\r\n\r\n.input-accordion-extra{\r\n    flex: 0 0 auto !important;\r\n    margin: 0 0.5em 0 auto;\r\n}\r\n\r\ndiv.accordions > div.input-accordion{\r\n    min-width: fit-content !important;\r\n}\r\n\r\ndiv.accordions > div.gradio-accordion .label-wrap span{\r\n    white-space: nowrap;\r\n    margin-right: 
0.25em;\r\n}\r\n\r\ndiv.accordions{\r\n    gap: 0.5em;\r\n}\r\n\r\ndiv.accordions > div.input-accordion.input-accordion-open{\r\n    flex: 1 auto;\r\n    flex-flow: column;\r\n}\r\n\r\n\r\n/* sticky right hand columns */\r\n\r\n#img2img_results, #txt2img_results, #extras_results {\r\n    position: sticky;\r\n    top: 0.5em;\r\n}\r\n\r\nbody.resizing {\r\n    cursor: col-resize !important;\r\n}\r\n\r\nbody.resizing * {\r\n    pointer-events: none !important;\r\n}\r\n\r\nbody.resizing .resize-handle {\r\n    pointer-events: initial !important;\r\n}\r\n\r\n.resize-handle {\r\n    position: relative;\r\n    cursor: col-resize;\r\n    grid-column: 2 / 3;\r\n    min-width: 16px !important;\r\n    max-width: 16px !important;\r\n    height: 100%;\r\n}\r\n\r\n.resize-handle::after {\r\n    content: '';\r\n    position: absolute;\r\n    top: 0;\r\n    bottom: 0;\r\n    left: 7.5px;\r\n    border-left: 1px dashed var(--border-color-primary);\r\n}\r\n\r\n/* ========================= */\r\n.extra-network-pane {\r\n    display: flex;\r\n    height: calc(100vh - 24rem);\r\n    resize: vertical;\r\n    min-height: 52rem;\r\n    flex-direction: column;\r\n    overflow: hidden;\r\n}\r\n\r\n.extra-network-pane .extra-network-pane-content-dirs {\r\n    display: flex;\r\n    flex: 1;\r\n    flex-direction: column;\r\n    overflow: hidden;\r\n}\r\n\r\n.extra-network-pane .extra-network-pane-content-tree {\r\n    display: flex;\r\n    flex: 1;\r\n    overflow: hidden;\r\n}\r\n\r\n.extra-network-dirs-hidden .extra-network-dirs{ display: none; }\r\n.extra-network-dirs-hidden .extra-network-tree{ display: none; }\r\n.extra-network-dirs-hidden .resize-handle { display: none; }\r\n.extra-network-dirs-hidden .resize-handle-row { display: flex !important; }\r\n\r\n.extra-network-pane .extra-network-tree {\r\n    flex: 1;\r\n    font-size: 1rem;\r\n    border: 1px solid var(--block-border-color);\r\n    overflow: clip auto !important;\r\n}\r\n\r\n.extra-network-pane .extra-network-cards {\r\n    
flex: 3;\r\n    overflow: clip auto !important;\r\n    border: 1px solid var(--block-border-color);\r\n}\r\n\r\n.extra-network-pane .extra-network-tree .tree-list {\r\n    flex: 1;\r\n    display: flex;\r\n    flex-direction: column;\r\n    padding: 0;\r\n    width: 100%;\r\n    overflow: hidden;\r\n}\r\n\r\n\r\n.extra-network-pane .extra-network-cards::-webkit-scrollbar,\r\n.extra-network-pane .extra-network-tree::-webkit-scrollbar {\r\n    background-color: transparent;\r\n    width: 16px;\r\n}\r\n\r\n.extra-network-pane .extra-network-cards::-webkit-scrollbar-track,\r\n.extra-network-pane .extra-network-tree::-webkit-scrollbar-track {\r\n    background-color: transparent;\r\n    background-clip: content-box;\r\n}\r\n\r\n.extra-network-pane .extra-network-cards::-webkit-scrollbar-thumb,\r\n.extra-network-pane .extra-network-tree::-webkit-scrollbar-thumb {\r\n    background-color: var(--border-color-primary);\r\n    border-radius: 16px;\r\n    border: 4px solid var(--background-fill-primary);\r\n}\r\n\r\n.extra-network-pane .extra-network-cards::-webkit-scrollbar-button,\r\n.extra-network-pane .extra-network-tree::-webkit-scrollbar-button {\r\n    display: none;\r\n}\r\n\r\n.extra-network-control {\r\n    position: relative;\r\n    display: flex;\r\n    width: 100%;\r\n    padding: 0 !important;\r\n    margin-top: 0 !important;\r\n    margin-bottom: 0 !important;\r\n    font-size: 1rem;\r\n    text-align: left;\r\n    user-select: none;\r\n    background-color: transparent;\r\n    border: none;\r\n    transition: background 33.333ms linear;\r\n    grid-template-rows: min-content;\r\n    grid-template-columns: minmax(0, auto) repeat(4, min-content);\r\n    grid-gap: 0.1rem;\r\n    align-items: start;\r\n}\r\n\r\n.extra-network-control small{\r\n    color: var(--input-placeholder-color);\r\n    line-height: 2.2rem;\r\n    margin: 0 0.5rem 0 0.75rem;\r\n}\r\n\r\n.extra-network-tree .tree-list--tree {}\r\n\r\n/* Remove auto indentation from tree. 
Will be overridden later. */\r\n.extra-network-tree .tree-list--subgroup {\r\n    margin: 0 !important;\r\n    padding: 0 !important;\r\n    box-shadow: 0.5rem 0 0 var(--body-background-fill) inset,\r\n                0.7rem 0 0 var(--neutral-800) inset;\r\n}\r\n\r\n/* Set indentation for each depth of tree. */\r\n.extra-network-tree .tree-list--subgroup > .tree-list-item {\r\n    margin-left: 0.4rem !important;\r\n    padding-left: 0.4rem !important;\r\n}\r\n\r\n/* Styles for tree <li> elements. */\r\n.extra-network-tree .tree-list-item {\r\n    list-style: none;\r\n    position: relative;\r\n    background-color: transparent;\r\n}\r\n\r\n/* Directory <ul> visibility based on data-expanded attribute. */\r\n.extra-network-tree .tree-list-content+.tree-list--subgroup {\r\n    height: 0;\r\n    visibility: hidden;\r\n    opacity: 0;\r\n}\r\n\r\n.extra-network-tree .tree-list-content[data-expanded]+.tree-list--subgroup {\r\n    height: auto;\r\n    visibility: visible;\r\n    opacity: 1;\r\n}\r\n\r\n/* File <li> */\r\n.extra-network-tree .tree-list-item--subitem {\r\n    padding-top: 0 !important;\r\n    padding-bottom: 0 !important;\r\n    margin-top: 0 !important;\r\n    margin-bottom: 0 !important;\r\n}\r\n\r\n/* <li> containing <ul> */\r\n.extra-network-tree .tree-list-item--has-subitem {}\r\n\r\n/* BUTTON ELEMENTS */\r\n/* <button> */\r\n.extra-network-tree .tree-list-content {\r\n    position: relative;\r\n    display: grid;\r\n    width: 100%;\r\n    padding: 0 !important;\r\n    margin-top: 0 !important;\r\n    margin-bottom: 0 !important;\r\n    font-size: 1rem;\r\n    text-align: left;\r\n    user-select: none;\r\n    background-color: transparent;\r\n    border: none;\r\n    transition: background 33.333ms linear;\r\n    grid-template-rows: min-content;\r\n    grid-template-areas: \"leading-action leading-visual label trailing-visual trailing-action\";\r\n    grid-template-columns: min-content min-content minmax(0, auto) min-content min-content;\r\n    
grid-gap: 0.1rem;\r\n    align-items: start;\r\n    flex-grow: 1;\r\n    flex-basis: 100%;\r\n}\r\n/* Buttons for directories. */\r\n.extra-network-tree .tree-list-content-dir {}   \r\n\r\n/* Buttons for files. */\r\n.extra-network-tree .tree-list-item--has-subitem .tree-list--subgroup > li:first-child {\r\n    padding-top: 0.5rem !important;\r\n}\r\n\r\n.dark .extra-network-tree div.tree-list-content:hover {\r\n    -webkit-transition: all 0.05s ease-in-out;\r\n\ttransition: all 0.05s ease-in-out;\r\n    background-color: var(--neutral-800);\r\n}\r\n\r\n.dark .extra-network-tree div.tree-list-content[data-selected] {\r\n    background-color: var(--neutral-700);\r\n}\r\n\r\n.extra-network-tree div.tree-list-content[data-selected] {\r\n    background-color: var(--neutral-300);\r\n}\r\n\r\n.extra-network-tree div.tree-list-content:hover {\r\n    -webkit-transition: all 0.05s ease-in-out;\r\n\ttransition: all 0.05s ease-in-out;\r\n    background-color: var(--neutral-200);\r\n}\r\n\r\n/* ==== CHEVRON ICON ACTIONS ==== */\r\n/* Define the animation for the arrow when it is clicked. */\r\n.extra-network-tree .tree-list-content-dir .tree-list-item-action-chevron {\r\n    -ms-transform: rotate(135deg);\r\n    -webkit-transform: rotate(135deg);\r\n    transform: rotate(135deg);\r\n    transition: transform 0.2s;\r\n}\r\n\r\n.extra-network-tree .tree-list-content-dir[data-expanded] .tree-list-item-action-chevron {\r\n    -ms-transform: rotate(225deg);\r\n    -webkit-transform: rotate(225deg);\r\n    transform: rotate(225deg);\r\n    transition: transform 0.2s;\r\n}\r\n\r\n.tree-list-item-action-chevron {\r\n    display: inline-flex;\r\n    /* Uses box shadow to generate a pseudo chevron `>` icon. 
*/\r\n    padding: 0.3rem;\r\n    box-shadow: 0.1rem 0.1rem 0 0 var(--neutral-200) inset;\r\n    transform: rotate(135deg);\r\n}\r\n\r\n/* ==== SEARCH INPUT ACTIONS ==== */\r\n/* Add icon to left side of <input> */\r\n.extra-network-control .extra-network-control--search::before {\r\n    content: \"🔎︎\";\r\n    position: absolute;\r\n    margin: 0.5rem;\r\n    font-size: 1rem;\r\n    color: var(--input-placeholder-color);\r\n}\r\n\r\n.extra-network-control .extra-network-control--search {\r\n    display: inline-flex;\r\n    position: relative;\r\n}\r\n\r\n.extra-network-control .extra-network-control--search .extra-network-control--search-text {\r\n    border: 1px solid var(--button-secondary-border-color);\r\n    border-radius: 0.5rem;\r\n    color: var(--button-secondary-text-color);\r\n    background-color: transparent;\r\n    width: 100%;\r\n    padding-left: 2rem;\r\n    line-height: 1rem;\r\n}\r\n\r\n\r\n.extra-network-control .extra-network-control--search .extra-network-control--search-text::placeholder {\r\n    color: var(--input-placeholder-color);\r\n}\r\n\r\n\r\n/* <input> clear button (x on right side) styling */\r\n.extra-network-control .extra-network-control--search .extra-network-control--search-text::-webkit-search-cancel-button {\r\n    -webkit-appearance: none;\r\n    appearance: none;\r\n    cursor: pointer;\r\n    height: 1rem;\r\n    width: 1rem;\r\n    mask-image: url('data:image/svg+xml,<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"24\" height=\"24\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"black\" stroke-width=\"4\" stroke-linecap=\"round\" stroke-linejoin=\"round\"><line x1=\"18\" y1=\"6\" x2=\"6\" y2=\"18\"></line><line x1=\"6\" y1=\"6\" x2=\"18\" y2=\"18\"></line></svg>');\r\n    mask-repeat: no-repeat;\r\n    mask-position: center center;\r\n    mask-size: 100%;\r\n    background-color: var(--input-placeholder-color);\r\n}\r\n\r\n/* ==== SORT ICON ACTIONS ==== */\r\n.extra-network-control .extra-network-control--sort {\r\n    
padding: 0.25rem;\r\n    display: inline-flex;\r\n    cursor: pointer;\r\n    justify-self: center;\r\n    align-self: center;\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort .extra-network-control--sort-icon {\r\n    height: 1.5rem;\r\n    width: 1.5rem;\r\n    mask-repeat: no-repeat;\r\n    mask-position: center center;\r\n    mask-size: 100%;\r\n    background-color: var(--input-placeholder-color);\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort[data-sortkey=\"default\"] .extra-network-control--sort-icon {\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path fill-rule=\"evenodd\" clip-rule=\"evenodd\" d=\"M1 5C1 3.34315 2.34315 2 4 2H8.43845C9.81505 2 11.015 2.93689 11.3489 4.27239L11.7808 6H13.5H20C21.6569 6 23 7.34315 23 9V11C23 11.5523 22.5523 12 22 12C21.4477 12 21 11.5523 21 11V9C21 8.44772 20.5523 8 20 8H13.5H11.7808H4C3.44772 8 3 8.44772 3 9V10V19C3 19.5523 3.44772 20 4 20H9C9.55228 20 10 20.4477 10 21C10 21.5523 9.55228 22 9 22H4C2.34315 22 1 20.6569 1 19V10V9V5ZM3 6.17071C3.31278 6.06015 3.64936 6 4 6H9.71922L9.40859 4.75746C9.2973 4.3123 8.89732 4 8.43845 4H4C3.44772 4 3 4.44772 3 5V6.17071ZM20.1716 18.7574C20.6951 17.967 21 17.0191 21 16C21 13.2386 18.7614 11 16 11C13.2386 11 11 13.2386 11 16C11 18.7614 13.2386 21 16 21C17.0191 21 17.967 20.6951 18.7574 20.1716L21.2929 22.7071C21.6834 23.0976 22.3166 23.0976 22.7071 22.7071C23.0976 22.3166 23.0976 21.6834 22.7071 21.2929L20.1716 18.7574ZM13 16C13 14.3431 14.3431 13 16 13C17.6569 13 19 14.3431 19 16C19 17.6569 17.6569 19 16 19C14.3431 19 13 17.6569 13 16Z\" fill=\"%23000000\"></path></g></svg>');\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort[data-sortkey=\"name\"] .extra-network-control--sort-icon {\r\n    
mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path fill-rule=\"evenodd\" clip-rule=\"evenodd\" d=\"M17.1841 6.69223C17.063 6.42309 16.7953 6.25 16.5002 6.25C16.2051 6.25 15.9374 6.42309 15.8162 6.69223L11.3162 16.6922C11.1463 17.07 11.3147 17.514 11.6924 17.6839C12.0701 17.8539 12.5141 17.6855 12.6841 17.3078L14.1215 14.1136H18.8789L20.3162 17.3078C20.4862 17.6855 20.9302 17.8539 21.308 17.6839C21.6857 17.514 21.8541 17.07 21.6841 16.6922L17.1841 6.69223ZM16.5002 8.82764L14.7965 12.6136H18.2039L16.5002 8.82764Z\" fill=\"%231C274C\"></path><path opacity=\"0.5\" fill-rule=\"evenodd\" clip-rule=\"evenodd\" d=\"M2.25 7C2.25 6.58579 2.58579 6.25 3 6.25H13C13.4142 6.25 13.75 6.58579 13.75 7C13.75 7.41421 13.4142 7.75 13 7.75H3C2.58579 7.75 2.25 7.41421 2.25 7Z\" fill=\"%231C274C\"></path><path opacity=\"0.5\" d=\"M2.25 12C2.25 11.5858 2.58579 11.25 3 11.25H10C10.4142 11.25 10.75 11.5858 10.75 12C10.75 12.4142 10.4142 12.75 10 12.75H3C2.58579 12.75 2.25 12.4142 2.25 12Z\" fill=\"%231C274C\"></path><path opacity=\"0.5\" d=\"M2.25 17C2.25 16.5858 2.58579 16.25 3 16.25H8C8.41421 16.25 8.75 16.5858 8.75 17C8.75 17.4142 8.41421 17.75 8 17.75H3C2.58579 17.75 2.25 17.4142 2.25 17Z\" fill=\"%231C274C\"></path></g></svg>');\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort[data-sortkey=\"date_created\"] .extra-network-control--sort-icon {\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path d=\"M17 11C14.2386 11 12 13.2386 12 16C12 18.7614 14.2386 21 17 21C19.7614 21 22 18.7614 22 16C22 13.2386 
19.7614 11 17 11ZM17 11V9M2 9V15.8C2 16.9201 2 17.4802 2.21799 17.908C2.40973 18.2843 2.71569 18.5903 3.09202 18.782C3.51984 19 4.0799 19 5.2 19H13M2 9V8.2C2 7.0799 2 6.51984 2.21799 6.09202C2.40973 5.71569 2.71569 5.40973 3.09202 5.21799C3.51984 5 4.0799 5 5.2 5H13.8C14.9201 5 15.4802 5 15.908 5.21799C16.2843 5.40973 16.5903 5.71569 16.782 6.09202C17 6.51984 17 7.0799 17 8.2V9M2 9H17M5 3V5M14 3V5M15 16H17M17 16H19M17 16V14M17 16V18\" stroke=\"black\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></path></g></svg>');\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort[data-sortkey=\"date_modified\"] .extra-network-control--sort-icon {\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path d=\"M10 21H6.2C5.0799 21 4.51984 21 4.09202 20.782C3.71569 20.5903 3.40973 20.2843 3.21799 19.908C3 19.4802 3 18.9201 3 17.8V8.2C3 7.0799 3 6.51984 3.21799 6.09202C3.40973 5.71569 3.71569 5.40973 4.09202 5.21799C4.51984 5 5.0799 5 6.2 5H17.8C18.9201 5 19.4802 5 19.908 5.21799C20.2843 5.40973 20.5903 5.71569 20.782 6.09202C21 6.51984 21 7.0799 21 8.2V10M7 3V5M17 3V5M3 9H21M13.5 13.0001L7 13M10 17.0001L7 17M14 21L16.025 20.595C16.2015 20.5597 16.2898 20.542 16.3721 20.5097C16.4452 20.4811 16.5147 20.4439 16.579 20.399C16.6516 20.3484 16.7152 20.2848 16.8426 20.1574L21 16C21.5523 15.4477 21.5523 14.5523 21 14C20.4477 13.4477 19.5523 13.4477 19 14L14.8426 18.1574C14.7152 18.2848 14.6516 18.3484 14.601 18.421C14.5561 18.4853 14.5189 18.5548 14.4903 18.6279C14.458 18.7102 14.4403 18.7985 14.405 18.975L14 21Z\" stroke=\"black\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></path></g></svg>');\r\n}\r\n\r\n/* ==== SORT DIRECTION ICON ACTIONS ==== */\r\n.extra-network-control 
.extra-network-control--sort-dir {\r\n    padding: 0.25rem;\r\n    display: inline-flex;\r\n    cursor: pointer;\r\n    justify-self: center;\r\n    align-self: center;\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort-dir .extra-network-control--sort-dir-icon {\r\n    height: 1.5rem;\r\n    width: 1.5rem;\r\n    mask-repeat: no-repeat;\r\n    mask-position: center center;\r\n    mask-size: 100%;\r\n    background-color: var(--input-placeholder-color);\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort-dir[data-sortdir=\"Ascending\"] .extra-network-control--sort-dir-icon {\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path d=\"M13 12H21M13 8H21M13 16H21M6 7V17M6 7L3 10M6 7L9 10\" stroke=\"black\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></path></g></svg>');\r\n}\r\n\r\n.extra-network-control .extra-network-control--sort-dir[data-sortdir=\"Descending\"] .extra-network-control--sort-dir-icon {\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 24 24\" fill=\"none\" xmlns=\"http://www.w3.org/2000/svg\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path d=\"M13 12H21M13 8H21M13 16H21M6 7V17M6 17L3 14M6 17L9 14\" stroke=\"black\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></path></g></svg>');\r\n}\r\n\r\n/* ==== TREE VIEW ICON ACTIONS ==== */\r\n.extra-network-control .extra-network-control--tree-view {\r\n    padding: 0.25rem;\r\n    display: inline-flex;\r\n    cursor: pointer;\r\n    justify-self: center;\r\n    align-self: center;\r\n}\r\n\r\n.extra-network-control .extra-network-control--tree-view 
.extra-network-control--tree-view-icon {\r\n    height: 1.5rem;\r\n    width: 1.5rem;\r\n    mask-image: url('data:image/svg+xml,<svg viewBox=\"0 0 16 16\" version=\"1.1\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" fill=\"black\"><g id=\"SVGRepo_bgCarrier\" stroke-width=\"0\"></g><g id=\"SVGRepo_tracerCarrier\" stroke-linecap=\"round\" stroke-linejoin=\"round\"></g><g id=\"SVGRepo_iconCarrier\"><path fill=\"black\" d=\"M16 10v-4h-11v1h-2v-3h9v-4h-12v4h2v10h3v2h11v-4h-11v1h-2v-5h2v2z\"></path></g></svg>');\r\n    mask-repeat: no-repeat;\r\n    mask-position: center center;\r\n    mask-size: 100%;\r\n    background-color: var(--input-placeholder-color);\r\n}\r\n\r\n.extra-network-control .extra-network-control--enabled {\r\n    background-color: rgba(0, 0, 0, 0.1);\r\n    border-radius: 0.25rem;\r\n}\r\n\r\n.dark .extra-network-control .extra-network-control--enabled {\r\n    background-color: rgba(255, 255, 255, 0.15);\r\n}\r\n\r\n.extra-network-control .extra-network-control--enabled .extra-network-control--icon{\r\n    background-color: var(--button-secondary-text-color);\r\n}\r\n\r\n/* ==== REFRESH ICON ACTIONS ==== */\r\n.extra-network-control .extra-network-control--refresh {\r\n    padding: 0.25rem;\r\n    display: inline-flex;\r\n    cursor: pointer;\r\n    justify-self: center;\r\n    align-self: center;\r\n}\r\n\r\n.extra-network-control .extra-network-control--refresh .extra-network-control--refresh-icon {\r\n    height: 1.5rem;\r\n    width: 1.5rem;\r\n    mask-image: url('data:image/svg+xml,<svg xmlns=\"http://www.w3.org/2000/svg\" width=\"24\" height=\"24\" viewBox=\"0 0 24 24\" fill=\"none\" stroke=\"black\" stroke-width=\"2\" stroke-linecap=\"round\" stroke-linejoin=\"bevel\"><path d=\"M21.5 2v6h-6M21.34 15.57a10 10 0 1 1-.57-8.38\"/></svg>');\r\n    mask-repeat: no-repeat;\r\n    mask-position: center center;\r\n    mask-size: 100%;\r\n    background-color: 
var(--input-placeholder-color);\r\n}\r\n\r\n.extra-network-control .extra-network-control--refresh-icon:active {\r\n    -ms-transform: rotate(180deg);\r\n    -webkit-transform: rotate(180deg);\r\n    transform: rotate(180deg);\r\n    transition: transform 0.2s;\r\n}\r\n\r\n/* ==== TREE GRID CONFIG ==== */\r\n\r\n/* Text for button. */\r\n.extra-network-tree .tree-list-item-label {\r\n    position: relative;\r\n    line-height: 1.25rem;\r\n    color: var(--button-secondary-text-color);\r\n    grid-area: label;\r\n    padding-left: 0.5rem;\r\n}\r\n\r\n/* Text for button truncated. */\r\n.extra-network-tree .tree-list-item-label--truncate {\r\n    overflow: hidden;\r\n    text-overflow: ellipsis;\r\n    white-space: nowrap;\r\n}\r\n\r\n/* Icon for button. */\r\n.extra-network-tree .tree-list-item-visual {\r\n    min-height: 1rem;\r\n    color: var(--button-secondary-text-color);\r\n    pointer-events: none;\r\n    align-items: right;\r\n}\r\n\r\n\r\n/* Icon for button when it is before label. */\r\n.extra-network-tree .tree-list-item-visual--leading {\r\n    grid-area: leading-visual;\r\n    width: 1rem;\r\n    text-align: right;\r\n}\r\n\r\n/* Icon for button when it is after label. */\r\n.extra-network-tree .tree-list-item-visual--trailing {\r\n    grid-area: trailing-visual;\r\n    width: 1rem;\r\n    text-align: right;\r\n}\r\n\r\n/* Dropdown arrow for button. 
*/\r\n.extra-network-tree .tree-list-item-action--leading {\r\n    margin-right: 0.5rem;\r\n    margin-left: 0.2rem;\r\n}\r\n\r\n.extra-network-tree .tree-list-content-file .tree-list-item-action--leading {\r\n    visibility: hidden;\r\n}\r\n\r\n.extra-network-tree .tree-list-item-action--leading {\r\n    grid-area: leading-action;\r\n}\r\n\r\n.extra-network-tree .tree-list-item-action--trailing {\r\n    grid-area: trailing-action;\r\n    display: inline-flex;\r\n}\r\n\r\n.extra-network-tree .tree-list-content .button-row {\r\n    display: inline-flex;\r\n    visibility: hidden;\r\n    color: var(--button-secondary-text-color);\r\n    width: 0;\r\n}\r\n\r\n.extra-network-tree .tree-list-content:hover .button-row {\r\n    visibility: visible;\r\n    width: auto;\r\n}\r\n"
  },
  {
    "path": "test/__init__.py",
    "content": ""
  },
  {
    "path": "test/conftest.py",
    "content": "import base64\nimport os\n\nimport pytest\n\ntest_files_path = os.path.dirname(__file__) + \"/test_files\"\ntest_outputs_path = os.path.dirname(__file__) + \"/test_outputs\"\n\n\ndef pytest_configure(config):\n    # We don't want to fail on Py.test command line arguments being\n    # parsed by webui:\n    os.environ.setdefault(\"IGNORE_CMD_ARGS_ERRORS\", \"1\")\n\n\ndef file_to_base64(filename):\n    with open(filename, \"rb\") as file:\n        data = file.read()\n\n    base64_str = str(base64.b64encode(data), \"utf-8\")\n    return \"data:image/png;base64,\" + base64_str\n\n\n@pytest.fixture(scope=\"session\")  # session so we don't read this over and over\ndef img2img_basic_image_base64() -> str:\n    return file_to_base64(os.path.join(test_files_path, \"img2img_basic.png\"))\n\n\n@pytest.fixture(scope=\"session\")  # session so we don't read this over and over\ndef mask_basic_image_base64() -> str:\n    return file_to_base64(os.path.join(test_files_path, \"mask_basic.png\"))\n\n\n@pytest.fixture(scope=\"session\")\ndef initialize() -> None:\n    import webui  # noqa: F401\n"
  },
  {
    "path": "test/test_extras.py",
    "content": "import requests\n\n\ndef test_simple_upscaling_performed(base_url, img2img_basic_image_base64):\n    payload = {\n        \"resize_mode\": 0,\n        \"show_extras_results\": True,\n        \"gfpgan_visibility\": 0,\n        \"codeformer_visibility\": 0,\n        \"codeformer_weight\": 0,\n        \"upscaling_resize\": 2,\n        \"upscaling_resize_w\": 128,\n        \"upscaling_resize_h\": 128,\n        \"upscaling_crop\": True,\n        \"upscaler_1\": \"Lanczos\",\n        \"upscaler_2\": \"None\",\n        \"extras_upscaler_2_visibility\": 0,\n        \"image\": img2img_basic_image_base64,\n    }\n    assert requests.post(f\"{base_url}/sdapi/v1/extra-single-image\", json=payload).status_code == 200\n\n\ndef test_png_info_performed(base_url, img2img_basic_image_base64):\n    payload = {\n        \"image\": img2img_basic_image_base64,\n    }\n    assert requests.post(f\"{base_url}/sdapi/v1/extra-single-image\", json=payload).status_code == 200\n\n\ndef test_interrogate_performed(base_url, img2img_basic_image_base64):\n    payload = {\n        \"image\": img2img_basic_image_base64,\n        \"model\": \"clip\",\n    }\n    assert requests.post(f\"{base_url}/sdapi/v1/extra-single-image\", json=payload).status_code == 200\n"
  },
  {
    "path": "test/test_face_restorers.py",
    "content": "import os\nfrom test.conftest import test_files_path, test_outputs_path\n\nimport numpy as np\nimport pytest\nfrom PIL import Image\n\n\n@pytest.mark.usefixtures(\"initialize\")\n@pytest.mark.parametrize(\"restorer_name\", [\"gfpgan\", \"codeformer\"])\ndef test_face_restorers(restorer_name):\n    from modules import shared\n\n    if restorer_name == \"gfpgan\":\n        from modules import gfpgan_model\n        gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path)\n        restorer = gfpgan_model.gfpgan_fix_faces\n    elif restorer_name == \"codeformer\":\n        from modules import codeformer_model\n        codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path)\n        restorer = codeformer_model.codeformer.restore\n    else:\n        raise NotImplementedError(\"...\")\n    img = Image.open(os.path.join(test_files_path, \"two-faces.jpg\"))\n    np_img = np.array(img, dtype=np.uint8)\n    fixed_image = restorer(np_img)\n    assert fixed_image.shape == np_img.shape\n    assert not np.allclose(fixed_image, np_img)  # should have visibly changed\n    Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f\"{restorer_name}.png\"))\n"
  },
  {
    "path": "test/test_img2img.py",
    "content": "\nimport pytest\nimport requests\n\n\n@pytest.fixture()\ndef url_img2img(base_url):\n    return f\"{base_url}/sdapi/v1/img2img\"\n\n\n@pytest.fixture()\ndef simple_img2img_request(img2img_basic_image_base64):\n    return {\n        \"batch_size\": 1,\n        \"cfg_scale\": 7,\n        \"denoising_strength\": 0.75,\n        \"eta\": 0,\n        \"height\": 64,\n        \"include_init_images\": False,\n        \"init_images\": [img2img_basic_image_base64],\n        \"inpaint_full_res\": False,\n        \"inpaint_full_res_padding\": 0,\n        \"inpainting_fill\": 0,\n        \"inpainting_mask_invert\": False,\n        \"mask\": None,\n        \"mask_blur\": 4,\n        \"n_iter\": 1,\n        \"negative_prompt\": \"\",\n        \"override_settings\": {},\n        \"prompt\": \"example prompt\",\n        \"resize_mode\": 0,\n        \"restore_faces\": False,\n        \"s_churn\": 0,\n        \"s_noise\": 1,\n        \"s_tmax\": 0,\n        \"s_tmin\": 0,\n        \"sampler_index\": \"Euler a\",\n        \"seed\": -1,\n        \"seed_resize_from_h\": -1,\n        \"seed_resize_from_w\": -1,\n        \"steps\": 3,\n        \"styles\": [],\n        \"subseed\": -1,\n        \"subseed_strength\": 0,\n        \"tiling\": False,\n        \"width\": 64,\n    }\n\n\ndef test_img2img_simple_performed(url_img2img, simple_img2img_request):\n    assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200\n\n\ndef test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):\n    simple_img2img_request[\"mask\"] = mask_basic_image_base64\n    assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200\n\n\ndef test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64):\n    simple_img2img_request[\"mask\"] = mask_basic_image_base64\n    simple_img2img_request[\"inpainting_mask_invert\"] = True\n    assert requests.post(url_img2img, 
json=simple_img2img_request).status_code == 200\n\n\ndef test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request):\n    simple_img2img_request[\"script_name\"] = \"sd upscale\"\n    simple_img2img_request[\"script_args\"] = [\"\", 8, \"Lanczos\", 2.0]\n    assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200\n"
  },
  {
    "path": "test/test_torch_utils.py",
    "content": "import types\n\nimport pytest\nimport torch\n\nfrom modules import torch_utils\n\n\n@pytest.mark.parametrize(\"wrapped\", [True, False])\ndef test_get_param(wrapped):\n    mod = torch.nn.Linear(1, 1)\n    cpu = torch.device(\"cpu\")\n    mod.to(dtype=torch.float16, device=cpu)\n    if wrapped:\n        # more or less how spandrel wraps a thing\n        mod = types.SimpleNamespace(model=mod)\n    p = torch_utils.get_param(mod)\n    assert p.dtype == torch.float16\n    assert p.device == cpu\n"
  },
  {
    "path": "test/test_txt2img.py",
    "content": "\nimport pytest\nimport requests\n\n\n@pytest.fixture()\ndef url_txt2img(base_url):\n    return f\"{base_url}/sdapi/v1/txt2img\"\n\n\n@pytest.fixture()\ndef simple_txt2img_request():\n    return {\n        \"batch_size\": 1,\n        \"cfg_scale\": 7,\n        \"denoising_strength\": 0,\n        \"enable_hr\": False,\n        \"eta\": 0,\n        \"firstphase_height\": 0,\n        \"firstphase_width\": 0,\n        \"height\": 64,\n        \"n_iter\": 1,\n        \"negative_prompt\": \"\",\n        \"prompt\": \"example prompt\",\n        \"restore_faces\": False,\n        \"s_churn\": 0,\n        \"s_noise\": 1,\n        \"s_tmax\": 0,\n        \"s_tmin\": 0,\n        \"sampler_index\": \"Euler a\",\n        \"seed\": -1,\n        \"seed_resize_from_h\": -1,\n        \"seed_resize_from_w\": -1,\n        \"steps\": 3,\n        \"styles\": [],\n        \"subseed\": -1,\n        \"subseed_strength\": 0,\n        \"tiling\": False,\n        \"width\": 64,\n    }\n\n\ndef test_txt2img_simple_performed(url_txt2img, simple_txt2img_request):\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_with_negative_prompt_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"negative_prompt\"] = \"example negative prompt\"\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_with_complex_prompt_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"prompt\"] = \"((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]\"\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_not_square_image_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"height\"] = 128\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_with_hrfix_performed(url_txt2img, 
simple_txt2img_request):\n    simple_txt2img_request[\"enable_hr\"] = True\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_with_tiling_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"tiling\"] = True\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_with_restore_faces_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"restore_faces\"] = True\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\n@pytest.mark.parametrize(\"sampler\", [\"PLMS\", \"DDIM\", \"UniPC\"])\ndef test_txt2img_with_vanilla_sampler_performed(url_txt2img, simple_txt2img_request, sampler):\n    simple_txt2img_request[\"sampler_index\"] = sampler\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_multiple_batches_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"n_iter\"] = 2\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n\n\ndef test_txt2img_batch_performed(url_txt2img, simple_txt2img_request):\n    simple_txt2img_request[\"batch_size\"] = 2\n    assert requests.post(url_txt2img, json=simple_txt2img_request).status_code == 200\n"
  },
  {
    "path": "test/test_utils.py",
    "content": "import pytest\nimport requests\n\n\ndef test_options_write(base_url):\n    url_options = f\"{base_url}/sdapi/v1/options\"\n    response = requests.get(url_options)\n    assert response.status_code == 200\n\n    pre_value = response.json()[\"send_seed\"]\n\n    assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200\n\n    response = requests.get(url_options)\n    assert response.status_code == 200\n    assert response.json()['send_seed'] == (not pre_value)\n\n    requests.post(url_options, json={\"send_seed\": pre_value})\n\n\n@pytest.mark.parametrize(\"url\", [\n    \"sdapi/v1/cmd-flags\",\n    \"sdapi/v1/samplers\",\n    \"sdapi/v1/upscalers\",\n    \"sdapi/v1/sd-models\",\n    \"sdapi/v1/hypernetworks\",\n    \"sdapi/v1/face-restorers\",\n    \"sdapi/v1/realesrgan-models\",\n    \"sdapi/v1/prompt-styles\",\n    \"sdapi/v1/embeddings\",\n])\ndef test_get_api_url(base_url, url):\n    assert requests.get(f\"{base_url}/{url}\").status_code == 200\n"
  },
  {
    "path": "textual_inversion_templates/hypernetwork.txt",
    "content": "a photo of a [filewords]\na rendering of a [filewords]\na cropped photo of the [filewords]\nthe photo of a [filewords]\na photo of a clean [filewords]\na photo of a dirty [filewords]\na dark photo of the [filewords]\na photo of my [filewords]\na photo of the cool [filewords]\na close-up photo of a [filewords]\na bright photo of the [filewords]\na cropped photo of a [filewords]\na photo of the [filewords]\na good photo of the [filewords]\na photo of one [filewords]\na close-up photo of the [filewords]\na rendition of the [filewords]\na photo of the clean [filewords]\na rendition of a [filewords]\na photo of a nice [filewords]\na good photo of a [filewords]\na photo of the nice [filewords]\na photo of the small [filewords]\na photo of the weird [filewords]\na photo of the large [filewords]\na photo of a cool [filewords]\na photo of a small [filewords]\n"
  },
  {
    "path": "textual_inversion_templates/none.txt",
    "content": "picture\n"
  },
  {
    "path": "textual_inversion_templates/style.txt",
    "content": "a painting, art by [name]\na rendering, art by [name]\na cropped painting, art by [name]\nthe painting, art by [name]\na clean painting, art by [name]\na dirty painting, art by [name]\na dark painting, art by [name]\na picture, art by [name]\na cool painting, art by [name]\na close-up painting, art by [name]\na bright painting, art by [name]\na cropped painting, art by [name]\na good painting, art by [name]\na close-up painting, art by [name]\na rendition, art by [name]\na nice painting, art by [name]\na small painting, art by [name]\na weird painting, art by [name]\na large painting, art by [name]\n"
  },
  {
    "path": "textual_inversion_templates/style_filewords.txt",
    "content": "a painting of [filewords], art by [name]\na rendering of [filewords], art by [name]\na cropped painting of [filewords], art by [name]\nthe painting of [filewords], art by [name]\na clean painting of [filewords], art by [name]\na dirty painting of [filewords], art by [name]\na dark painting of [filewords], art by [name]\na picture of [filewords], art by [name]\na cool painting of [filewords], art by [name]\na close-up painting of [filewords], art by [name]\na bright painting of [filewords], art by [name]\na cropped painting of [filewords], art by [name]\na good painting of [filewords], art by [name]\na close-up painting of [filewords], art by [name]\na rendition of [filewords], art by [name]\na nice painting of [filewords], art by [name]\na small painting of [filewords], art by [name]\na weird painting of [filewords], art by [name]\na large painting of [filewords], art by [name]\n"
  },
  {
    "path": "textual_inversion_templates/subject.txt",
    "content": "a photo of a [name]\na rendering of a [name]\na cropped photo of the [name]\nthe photo of a [name]\na photo of a clean [name]\na photo of a dirty [name]\na dark photo of the [name]\na photo of my [name]\na photo of the cool [name]\na close-up photo of a [name]\na bright photo of the [name]\na cropped photo of a [name]\na photo of the [name]\na good photo of the [name]\na photo of one [name]\na close-up photo of the [name]\na rendition of the [name]\na photo of the clean [name]\na rendition of a [name]\na photo of a nice [name]\na good photo of a [name]\na photo of the nice [name]\na photo of the small [name]\na photo of the weird [name]\na photo of the large [name]\na photo of a cool [name]\na photo of a small [name]\n"
  },
  {
    "path": "textual_inversion_templates/subject_filewords.txt",
    "content": "a photo of a [name], [filewords]\na rendering of a [name], [filewords]\na cropped photo of the [name], [filewords]\nthe photo of a [name], [filewords]\na photo of a clean [name], [filewords]\na photo of a dirty [name], [filewords]\na dark photo of the [name], [filewords]\na photo of my [name], [filewords]\na photo of the cool [name], [filewords]\na close-up photo of a [name], [filewords]\na bright photo of the [name], [filewords]\na cropped photo of a [name], [filewords]\na photo of the [name], [filewords]\na good photo of the [name], [filewords]\na photo of one [name], [filewords]\na close-up photo of the [name], [filewords]\na rendition of the [name], [filewords]\na photo of the clean [name], [filewords]\na rendition of a [name], [filewords]\na photo of a nice [name], [filewords]\na good photo of a [name], [filewords]\na photo of the nice [name], [filewords]\na photo of the small [name], [filewords]\na photo of the weird [name], [filewords]\na photo of the large [name], [filewords]\na photo of a cool [name], [filewords]\na photo of a small [name], [filewords]\n"
  },
  {
    "path": "webui-macos-env.sh",
    "content": "#!/bin/bash\n####################################################################\n#                          macOS defaults                          #\n# Please modify webui-user.sh to change these instead of this file #\n####################################################################\n\nexport install_dir=\"$HOME\"\nexport COMMANDLINE_ARGS=\"--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate\"\nexport PYTORCH_ENABLE_MPS_FALLBACK=1\n\nif [[ \"$(sysctl -n machdep.cpu.brand_string)\" =~ ^.*\"Intel\".*$ ]]; then\n    export TORCH_COMMAND=\"pip install torch==2.1.2 torchvision==0.16.2\"\nelse\n    export TORCH_COMMAND=\"pip install torch==2.3.1 torchvision==0.18.1\"\nfi\n\n####################################################################\n"
  },
  {
    "path": "webui.bat",
    "content": "@echo off\r\n\r\nif exist webui.settings.bat (\r\n    call webui.settings.bat\r\n)\r\n\r\nif not defined PYTHON (set PYTHON=python)\r\nif defined GIT (set \"GIT_PYTHON_GIT_EXECUTABLE=%GIT%\")\r\nif not defined VENV_DIR (set \"VENV_DIR=%~dp0%venv\")\r\n\r\nset SD_WEBUI_RESTART=tmp/restart\r\nset ERROR_REPORTING=FALSE\r\n\r\nmkdir tmp 2>NUL\r\n\r\n%PYTHON% -c \"\" >tmp/stdout.txt 2>tmp/stderr.txt\r\nif %ERRORLEVEL% == 0 goto :check_pip\r\necho Couldn't launch python\r\ngoto :show_stdout_stderr\r\n\r\n:check_pip\r\n%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt\r\nif %ERRORLEVEL% == 0 goto :start_venv\r\nif \"%PIP_INSTALLER_LOCATION%\" == \"\" goto :show_stdout_stderr\r\n%PYTHON% \"%PIP_INSTALLER_LOCATION%\" >tmp/stdout.txt 2>tmp/stderr.txt\r\nif %ERRORLEVEL% == 0 goto :start_venv\r\necho Couldn't install pip\r\ngoto :show_stdout_stderr\r\n\r\n:start_venv\r\nif [\"%VENV_DIR%\"] == [\"-\"] goto :skip_venv\r\nif [\"%SKIP_VENV%\"] == [\"1\"] goto :skip_venv\r\n\r\ndir \"%VENV_DIR%\\Scripts\\Python.exe\" >tmp/stdout.txt 2>tmp/stderr.txt\r\nif %ERRORLEVEL% == 0 goto :activate_venv\r\n\r\nfor /f \"delims=\" %%i in ('CALL %PYTHON% -c \"import sys; print(sys.executable)\"') do set PYTHON_FULLNAME=\"%%i\"\r\necho Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME%\r\n%PYTHON_FULLNAME% -m venv \"%VENV_DIR%\" >tmp/stdout.txt 2>tmp/stderr.txt\r\nif %ERRORLEVEL% == 0 goto :upgrade_pip\r\necho Unable to create venv in directory \"%VENV_DIR%\"\r\ngoto :show_stdout_stderr\r\n\r\n:upgrade_pip\r\n\"%VENV_DIR%\\Scripts\\Python.exe\" -m pip install --upgrade pip\r\nif %ERRORLEVEL% == 0 goto :activate_venv\r\necho Warning: Failed to upgrade PIP version\r\n\r\n:activate_venv\r\nset PYTHON=\"%VENV_DIR%\\Scripts\\Python.exe\"\r\ncall \"%VENV_DIR%\\Scripts\\activate.bat\"\r\necho venv %PYTHON%\r\n\r\n:skip_venv\r\nif [%ACCELERATE%] == [\"True\"] goto :accelerate\r\ngoto :launch\r\n\r\n:accelerate\r\necho Checking for accelerate\r\nset 
ACCELERATE=\"%VENV_DIR%\\Scripts\\accelerate.exe\"\r\nif EXIST %ACCELERATE% goto :accelerate_launch\r\n\r\n:launch\r\n%PYTHON% launch.py %*\r\nif EXIST tmp/restart goto :skip_venv\r\npause\r\nexit /b\r\n\r\n:accelerate_launch\r\necho Accelerating\r\n%ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py %*\r\nif EXIST tmp/restart goto :skip_venv\r\npause\r\nexit /b\r\n\r\n:show_stdout_stderr\r\n\r\necho.\r\necho exit code: %errorlevel%\r\n\r\nfor /f %%i in (\"tmp\\stdout.txt\") do set size=%%~zi\r\nif %size% equ 0 goto :show_stderr\r\necho.\r\necho stdout:\r\ntype tmp\\stdout.txt\r\n\r\n:show_stderr\r\nfor /f %%i in (\"tmp\\stderr.txt\") do set size=%%~zi\r\nif %size% equ 0 goto :endofscript\r\necho.\r\necho stderr:\r\ntype tmp\\stderr.txt\r\n\r\n:endofscript\r\n\r\necho.\r\necho Launch unsuccessful. Exiting.\r\npause\r\n"
  },
  {
    "path": "webui.py",
    "content": "from __future__ import annotations\r\n\r\nimport os\r\nimport time\r\n\r\nfrom modules import timer\r\nfrom modules import initialize_util\r\nfrom modules import initialize\r\n\r\nstartup_timer = timer.startup_timer\r\nstartup_timer.record(\"launcher\")\r\n\r\ninitialize.imports()\r\n\r\ninitialize.check_versions()\r\n\r\n\r\ndef create_api(app):\r\n    from modules.api.api import Api\r\n    from modules.call_queue import queue_lock\r\n\r\n    api = Api(app, queue_lock)\r\n    return api\r\n\r\n\r\ndef api_only():\r\n    from fastapi import FastAPI\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    initialize.initialize()\r\n\r\n    app = FastAPI()\r\n    initialize_util.setup_middleware(app)\r\n    api = create_api(app)\r\n\r\n    from modules import script_callbacks\r\n    script_callbacks.before_ui_callback()\r\n    script_callbacks.app_started_callback(None, app)\r\n\r\n    print(f\"Startup time: {startup_timer.summary()}.\")\r\n    api.launch(\r\n        server_name=initialize_util.gradio_server_name(),\r\n        port=cmd_opts.port if cmd_opts.port else 7861,\r\n        root_path=f\"/{cmd_opts.subpath}\" if cmd_opts.subpath else \"\"\r\n    )\r\n\r\n\r\ndef webui():\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    launch_api = cmd_opts.api\r\n    initialize.initialize()\r\n\r\n    from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks\r\n\r\n    while 1:\r\n        if shared.opts.clean_temp_dir_at_start:\r\n            ui_tempdir.cleanup_tmpdr()\r\n            startup_timer.record(\"cleanup temp dir\")\r\n\r\n        script_callbacks.before_ui_callback()\r\n        startup_timer.record(\"scripts before_ui_callback\")\r\n\r\n        shared.demo = ui.create_ui()\r\n        startup_timer.record(\"create ui\")\r\n\r\n        if not cmd_opts.no_gradio_queue:\r\n            shared.demo.queue(64)\r\n\r\n        gradio_auth_creds = list(initialize_util.get_gradio_auth_creds()) or 
None\r\n\r\n        auto_launch_browser = False\r\n        if os.getenv('SD_WEBUI_RESTARTING') != '1':\r\n            if shared.opts.auto_launch_browser == \"Remote\" or cmd_opts.autolaunch:\r\n                auto_launch_browser = True\r\n            elif shared.opts.auto_launch_browser == \"Local\":\r\n                auto_launch_browser = not cmd_opts.webui_is_non_local\r\n\r\n        app, local_url, share_url = shared.demo.launch(\r\n            share=cmd_opts.share,\r\n            server_name=initialize_util.gradio_server_name(),\r\n            server_port=cmd_opts.port,\r\n            ssl_keyfile=cmd_opts.tls_keyfile,\r\n            ssl_certfile=cmd_opts.tls_certfile,\r\n            ssl_verify=cmd_opts.disable_tls_verify,\r\n            debug=cmd_opts.gradio_debug,\r\n            auth=gradio_auth_creds,\r\n            inbrowser=auto_launch_browser,\r\n            prevent_thread_lock=True,\r\n            allowed_paths=cmd_opts.gradio_allowed_path,\r\n            app_kwargs={\r\n                \"docs_url\": \"/docs\",\r\n                \"redoc_url\": \"/redoc\",\r\n            },\r\n            root_path=f\"/{cmd_opts.subpath}\" if cmd_opts.subpath else \"\",\r\n        )\r\n\r\n        startup_timer.record(\"gradio launch\")\r\n\r\n        # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for\r\n        # an attacker to trick the user into opening a malicious HTML page, which makes a request to the\r\n        # running web ui and do whatever the attacker wants, including installing an extension and\r\n        # running its code. We disable this here. 
Suggested by RyotaK.\r\n        app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']\r\n\r\n        initialize_util.setup_middleware(app)\r\n\r\n        progress.setup_progress_api(app)\r\n        ui.setup_ui_api(app)\r\n\r\n        if launch_api:\r\n            create_api(app)\r\n\r\n        ui_extra_networks.add_pages_to_demo(app)\r\n\r\n        startup_timer.record(\"add APIs\")\r\n\r\n        with startup_timer.subcategory(\"app_started_callback\"):\r\n            script_callbacks.app_started_callback(shared.demo, app)\r\n\r\n        timer.startup_record = startup_timer.dump()\r\n        print(f\"Startup time: {startup_timer.summary()}.\")\r\n\r\n        try:\r\n            while True:\r\n                server_command = shared.state.wait_for_server_command(timeout=5)\r\n                if server_command:\r\n                    if server_command in (\"stop\", \"restart\"):\r\n                        break\r\n                    else:\r\n                        print(f\"Unknown server command: {server_command}\")\r\n        except KeyboardInterrupt:\r\n            print('Caught KeyboardInterrupt, stopping...')\r\n            server_command = \"stop\"\r\n\r\n        if server_command == \"stop\":\r\n            print(\"Stopping server...\")\r\n            # If we catch a keyboard interrupt, we want to stop the server and exit.\r\n            shared.demo.close()\r\n            break\r\n\r\n        # disable auto launch webui in browser for subsequent UI Reload\r\n        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')\r\n\r\n        print('Restarting UI...')\r\n        shared.demo.close()\r\n        time.sleep(0.5)\r\n        startup_timer.reset()\r\n        script_callbacks.app_reload_callback()\r\n        startup_timer.record(\"app reload callback\")\r\n        script_callbacks.script_unloaded_callback()\r\n        startup_timer.record(\"scripts unloaded callback\")\r\n        
initialize.initialize_rest(reload_script_modules=True)\r\n\r\n\r\nif __name__ == \"__main__\":\r\n    from modules.shared_cmd_options import cmd_opts\r\n\r\n    if cmd_opts.nowebui:\r\n        api_only()\r\n    else:\r\n        webui()\r\n"
  },
  {
    "path": "webui.sh",
    "content": "#!/usr/bin/env bash\n#################################################\n# Please do not make any changes to this file,  #\n# change the variables in webui-user.sh instead #\n#################################################\n\nSCRIPT_DIR=$( cd -- \"$( dirname -- \"${BASH_SOURCE[0]}\" )\" &> /dev/null && pwd )\n\n\n# If run from macOS, load defaults from webui-macos-env.sh\nif [[ \"$OSTYPE\" == \"darwin\"* ]]; then\n    if [[ -f \"$SCRIPT_DIR\"/webui-macos-env.sh ]]\n        then\n        source \"$SCRIPT_DIR\"/webui-macos-env.sh\n    fi\nfi\n\n# Read variables from webui-user.sh\n# shellcheck source=/dev/null\nif [[ -f \"$SCRIPT_DIR\"/webui-user.sh ]]\nthen\n    source \"$SCRIPT_DIR\"/webui-user.sh\nfi\n\n# If $venv_dir is \"-\", then disable venv support\nuse_venv=1\nif [[ $venv_dir == \"-\" ]]; then\n  use_venv=0\nfi\n\n# Set defaults\n# Install directory without trailing slash\nif [[ -z \"${install_dir}\" ]]\nthen\n    install_dir=\"$SCRIPT_DIR\"\nfi\n\n# Name of the subdirectory (defaults to stable-diffusion-webui)\nif [[ -z \"${clone_dir}\" ]]\nthen\n    clone_dir=\"stable-diffusion-webui\"\nfi\n\n# python3 executable\nif [[ -z \"${python_cmd}\" ]]\nthen\n  python_cmd=\"python3.10\"\nfi\nif [[ ! 
-x \"$(command -v \"${python_cmd}\")\" ]]\nthen\n  python_cmd=\"python3\"\nfi\n\n# git executable\nif [[ -z \"${GIT}\" ]]\nthen\n    export GIT=\"git\"\nelse\n    export GIT_PYTHON_GIT_EXECUTABLE=\"${GIT}\"\nfi\n\n# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)\nif [[ -z \"${venv_dir}\" ]] && [[ $use_venv -eq 1 ]]\nthen\n    venv_dir=\"venv\"\nfi\n\nif [[ -z \"${LAUNCH_SCRIPT}\" ]]\nthen\n    LAUNCH_SCRIPT=\"launch.py\"\nfi\n\n# this script cannot be run as root by default\ncan_run_as_root=0\n\n# read any command line flags to the webui.sh script\nwhile getopts \"f\" flag > /dev/null 2>&1\ndo\n    case ${flag} in\n        f) can_run_as_root=1;;\n        *) break;;\n    esac\ndone\n\n# Disable sentry logging\nexport ERROR_REPORTING=FALSE\n\n# Do not reinstall existing pip packages on Debian/Ubuntu\nexport PIP_IGNORE_INSTALLED=0\n\n# Pretty print\ndelimiter=\"################################################################\"\n\nprintf \"\\n%s\\n\" \"${delimiter}\"\nprintf \"\\e[1m\\e[32mInstall script for stable-diffusion + Web UI\\n\"\nprintf \"\\e[1m\\e[34mTested on Debian 11 (Bullseye), Fedora 34+ and openSUSE Leap 15.4 or newer.\\e[0m\"\nprintf \"\\n%s\\n\" \"${delimiter}\"\n\n# Do not run as root\nif [[ $(id -u) -eq 0 && can_run_as_root -eq 0 ]]\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: This script must not be launched as root, aborting...\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nelse\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Running on \\e[1m\\e[32m%s\\e[0m user\" \"$(whoami)\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\nfi\n\nif [[ $(getconf LONG_BIT) = 32 ]]\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: Unsupported Running on a 32bit OS\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nfi\n\nif [[ -d \"$SCRIPT_DIR/.git\" ]]\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Repo already 
cloned, using it as install directory\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    install_dir=\"${SCRIPT_DIR}/../\"\n    clone_dir=\"${SCRIPT_DIR##*/}\"\nfi\n\n# Check prerequisites\ngpu_info=$(lspci 2>/dev/null | grep -E \"VGA|Display\")\ncase \"$gpu_info\" in\n    *\"Navi 1\"*)\n        export HSA_OVERRIDE_GFX_VERSION=10.3.0\n        if [[ -z \"${TORCH_COMMAND}\" ]]\n        then\n            pyv=\"$(${python_cmd} -c 'import sys; print(f\"{sys.version_info[0]}.{sys.version_info[1]}\")')\"\n            # Using an old nightly compiled against rocm 5.2 for Navi1, see https://github.com/pytorch/pytorch/issues/106728#issuecomment-1749511711\n            if [[ $pyv == \"3.8\" ]]\n            then\n                export TORCH_COMMAND=\"pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp38-cp38-linux_x86_64.whl\"\n            elif [[ $pyv == \"3.9\" ]]\n            then\n                export TORCH_COMMAND=\"pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp39-cp39-linux_x86_64.whl\"\n            elif [[ $pyv == \"3.10\" ]]\n            then\n                export TORCH_COMMAND=\"pip install https://download.pytorch.org/whl/nightly/rocm5.2/torch-2.0.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl https://download.pytorch.org/whl/nightly/rocm5.2/torchvision-0.15.0.dev20230209%2Brocm5.2-cp310-cp310-linux_x86_64.whl\"\n            else\n                printf \"\\e[1m\\e[31mERROR: RX 5000 series GPUs python version must be between 3.8 and 3.10, aborting...\\e[0m\"\n                exit 1\n            fi\n        fi\n    ;;\n    *\"Navi 2\"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0\n    ;;\n    *\"Navi 3\"*) [[ -z \"${TORCH_COMMAND}\" ]] 
&& \\\n         export TORCH_COMMAND=\"pip install torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm5.7\"\n    ;;\n    *\"Renoir\"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n    ;;\n    *)\n    ;;\nesac\nif ! echo \"$gpu_info\" | grep -q \"NVIDIA\";\nthen\n    if echo \"$gpu_info\" | grep -q \"AMD\" && [[ -z \"${TORCH_COMMAND}\" ]]\n    then\n\t      export TORCH_COMMAND=\"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7\"\n    elif npu-smi info 2>/dev/null\n    then\n        export TORCH_COMMAND=\"pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu==2.1.0\"\n    fi\nfi\n\nfor preq in \"${GIT}\" \"${python_cmd}\"\ndo\n    if ! hash \"${preq}\" &>/dev/null\n    then\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"\\e[1m\\e[31mERROR: %s is not installed, aborting...\\e[0m\" \"${preq}\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        exit 1\n    fi\ndone\n\nif [[ $use_venv -eq 1 ]] && ! 
\"${python_cmd}\" -c \"import venv\" &>/dev/null\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"\\e[1m\\e[31mERROR: python3-venv is not installed, aborting...\\e[0m\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    exit 1\nfi\n\ncd \"${install_dir}\"/ || { printf \"\\e[1m\\e[31mERROR: Can't cd to %s/, aborting...\\e[0m\" \"${install_dir}\"; exit 1; }\nif [[ -d \"${clone_dir}\" ]]\nthen\n    cd \"${clone_dir}\"/ || { printf \"\\e[1m\\e[31mERROR: Can't cd to %s/%s/, aborting...\\e[0m\" \"${install_dir}\" \"${clone_dir}\"; exit 1; }\nelse\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Clone stable-diffusion-webui\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    \"${GIT}\" clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git \"${clone_dir}\"\n    cd \"${clone_dir}\"/ || { printf \"\\e[1m\\e[31mERROR: Can't cd to %s/%s/, aborting...\\e[0m\" \"${install_dir}\" \"${clone_dir}\"; exit 1; }\nfi\n\nif [[ $use_venv -eq 1 ]] && [[ -z \"${VIRTUAL_ENV}\" ]];\nthen\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"Create and activate python venv\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    cd \"${install_dir}\"/\"${clone_dir}\"/ || { printf \"\\e[1m\\e[31mERROR: Can't cd to %s/%s/, aborting...\\e[0m\" \"${install_dir}\" \"${clone_dir}\"; exit 1; }\n    if [[ ! 
-d \"${venv_dir}\" ]]\n    then\n        \"${python_cmd}\" -m venv \"${venv_dir}\"\n        \"${venv_dir}\"/bin/python -m pip install --upgrade pip\n        first_launch=1\n    fi\n    # shellcheck source=/dev/null\n    if [[ -f \"${venv_dir}\"/bin/activate ]]\n    then\n        source \"${venv_dir}\"/bin/activate\n        # ensure use of python from venv\n        python_cmd=\"${venv_dir}\"/bin/python\n    else\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"\\e[1m\\e[31mERROR: Cannot activate python venv, aborting...\\e[0m\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        exit 1\n    fi\nelse\n    printf \"\\n%s\\n\" \"${delimiter}\"\n    printf \"python venv already activate or run without venv: ${VIRTUAL_ENV}\"\n    printf \"\\n%s\\n\" \"${delimiter}\"\nfi\n\n# Try using TCMalloc on Linux\nprepare_tcmalloc() {\n    if [[ \"${OSTYPE}\" == \"linux\"* ]] && [[ -z \"${NO_TCMALLOC}\" ]] && [[ -z \"${LD_PRELOAD}\" ]]; then\n        # check glibc version\n        LIBC_VER=$(echo $(ldd --version | awk 'NR==1 {print $NF}') | grep -oP '\\d+\\.\\d+')\n        echo \"glibc version is $LIBC_VER\"\n        libc_vernum=$(expr $LIBC_VER)\n        # Since 2.34 libpthread is integrated into libc.so\n        libc_v234=2.34\n        # Define Tcmalloc Libs arrays\n        TCMALLOC_LIBS=(\"libtcmalloc(_minimal|)\\.so\\.\\d\" \"libtcmalloc\\.so\\.\\d\")\n        # Traversal array\n        for lib in \"${TCMALLOC_LIBS[@]}\"\n        do\n            # Determine which type of tcmalloc library the library supports\n            TCMALLOC=\"$(PATH=/sbin:/usr/sbin:$PATH ldconfig -p | grep -P $lib | head -n 1)\"\n            TC_INFO=(${TCMALLOC//=>/})\n            if [[ ! 
-z \"${TC_INFO}\" ]]; then\n                echo \"Check TCMalloc: ${TC_INFO}\"\n                # Determine if the library is linked to libpthread and resolve undefined symbol: pthread_key_create\n                if [ $(echo \"$libc_vernum < $libc_v234\" | bc) -eq 1 ]; then\n                    # glibc < 2.34 pthread_key_create into libpthread.so. check linking libpthread.so...\n                    if ldd ${TC_INFO[2]} | grep -q 'libpthread'; then\n                        echo \"$TC_INFO is linked with libpthread,execute LD_PRELOAD=${TC_INFO[2]}\"\n                        # set fullpath LD_PRELOAD (To be on the safe side)\n                        export LD_PRELOAD=\"${TC_INFO[2]}\"\n                        break\n                    else\n                        echo \"$TC_INFO is not linked with libpthread will trigger undefined symbol: pthread_Key_create error\"\n                    fi\n                else\n                    # Version 2.34 of libc.so (glibc) includes the pthread library IN GLIBC. (USE ubuntu 22.04 and modern linux system and WSL)\n                    # libc.so(glibc) is linked with a library that works in ALMOST ALL Linux userlands. SO NO CHECK!\n                    echo \"$TC_INFO is linked with libc.so,execute LD_PRELOAD=${TC_INFO[2]}\"\n                    # set fullpath LD_PRELOAD (To be on the safe side)\n                    export LD_PRELOAD=\"${TC_INFO[2]}\"\n                    break\n                fi\n            fi\n        done\n        if [[ -z \"${LD_PRELOAD}\" ]]; then\n            printf \"\\e[1m\\e[31mCannot locate TCMalloc. Do you have tcmalloc or google-perftool installed on your system? (improves CPU memory usage)\\e[0m\\n\"\n        fi\n    fi\n}\n\nKEEP_GOING=1\nexport SD_WEBUI_RESTART=tmp/restart\nwhile [[ \"$KEEP_GOING\" -eq \"1\" ]]; do\n    if [[ ! 
-z \"${ACCELERATE}\" ]] && [ \"${ACCELERATE}\" = \"True\" ] && [ -x \"$(command -v accelerate)\" ]; then\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"Accelerating launch.py...\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        prepare_tcmalloc\n        accelerate launch --num_cpu_threads_per_process=6 \"${LAUNCH_SCRIPT}\" \"$@\"\n    else\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        printf \"Launching launch.py...\"\n        printf \"\\n%s\\n\" \"${delimiter}\"\n        prepare_tcmalloc\n        \"${python_cmd}\" -u \"${LAUNCH_SCRIPT}\" \"$@\"\n    fi\n\n    if [[ ! -f tmp/restart ]]; then\n        KEEP_GOING=0\n    fi\ndone\n"
  }
]