[
  {
    "path": ".github/workflows/deploy.yaml",
    "content": "name: Deploy invoke-training docs\n\non:\n  push:\n    branches:\n      - main\n\npermissions:\n  contents: write\n\njobs:\n  deploy:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v4\n      - name: Configure Git Credentials\n        run: |\n          git config user.name github-actions[bot]\n          git config user.email 41898282+github-actions[bot]@users.noreply.github.com\n      - uses: actions/setup-python@v4\n        with:\n          python-version: \"3.10\"\n          cache: pip\n          cache-dependency-path: pyproject.toml\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          python -m pip install .[test]\n      - run: echo \"cache_id=$(date --utc '+%V')\" >> $GITHUB_ENV \n      - uses: actions/cache@v3\n        with:\n          key: mkdocs-material-${{ env.cache_id }}\n          path: .cache\n          restore-keys: |\n            mkdocs-material-\n      - run: mkdocs gh-deploy --force\n"
  },
  {
    "path": ".github/workflows/test.yaml",
    "content": "name: Test invoke-training\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n  workflow_dispatch:\n\njobs:\n  build:\n\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.12\"]\n\n    steps:\n    - uses: actions/checkout@v4\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v5\n      with:\n        python-version: ${{ matrix.python-version }}\n        cache: pip\n        cache-dependency-path: pyproject.toml\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        python -m pip install .[test]\n    - name: Ruff lint\n      run: |\n        ruff check --output-format=github .\n    - name: Ruff format\n      run: |\n        ruff format --check .\n    - name: Test with pytest\n      run: |\n        pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m \"not cuda and not loads_model\"\n    - name: Upload pytest test results\n      uses: actions/upload-artifact@v4\n      with:\n        name: pytest-results-${{ matrix.python-version }}\n        path: junit/test-results-${{ matrix.python-version }}.xml\n      # Use always() to always run this step to publish test results when there are test failures.\n      if: ${{ always() }}\n"
  },
  {
    "path": ".gitignore",
    "content": "/output/\n/test_configs/\n/data/\n\n# pyenv\n.python-version\n\n# VSCode\n.vscode/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\ncover/\njunit/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\n.pybuilder/\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n#   For a library or package, you might want to ignore these files since the code is\n#   intended to run in multiple environments; otherwise, check them in:\n# .python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# poetry\n#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.\n#   This is especially recommended for binary packages to ensure reproducibility, and is more\n#   commonly ignored for libraries.\n#   
https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control\n#poetry.lock\n\n# pdm\n#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.\n#pdm.lock\n#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it\n#   in version control.\n#   https://pdm.fming.dev/#use-with-ide\n.pdm.toml\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# Cython debug symbols\ncython_debug/\n\n# PyCharm\n#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can\n#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore\n#  and can be added to the global gitignore or merged into this file.  For a more nuclear\n#  option (not recommended) you can uncomment the following to ignore the entire idea folder.\n#.idea/\n.aider*\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# See https://pre-commit.com/ for usage and config.\nrepos:\n- repo: https://github.com/astral-sh/ruff-pre-commit\n  # Ruff version.\n  rev: v0.1.7\n  hooks:\n    # Run the linter.\n    - id: ruff\n    # Run the formatter.\n    - id: ruff-format\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).\n\n> [!WARNING] > `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0.\n> In the meantime, I recommend pinning to a specific commit hash.\n\n## Documentation\n\n<https://invoke-ai.github.io/invoke-training/>\n\n## Training Modes\n\n- Stable Diffusion\n  - LoRA\n  - DreamBooth LoRA\n  - Textual Inversion\n- Stable Diffusion XL\n  - Full finetuning\n  - LoRA\n  - DreamBooth LoRA\n  - Textual Inversion\n  - LoRA and Textual Inversion\n\nMore training modes coming soon!\n\n## Installation\n\nSee the [Installation](https://invoke-ai.github.io/invoke-training/get-started/installation/) section of the documentation.\n\n## Quick Start\n\n`invoke-training` pipelines can be configured and launched from either the CLI or the GUI.\n\n### CLI\n\nRun training via the CLI with type-checked YAML configuration files for maximum control:\n\n```bash\ninvoke-train --cfg-file src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml\n```\n\n### GUI\n\nRun training via the GUI for a simpler starting point.\n\n```bash\ninvoke-train-ui\n\n# Or, you can optionally override the default host and port:\ninvoke-train-ui --host 0.0.0.0 --port 1234\n```\n\n## Features\n\nTraining progress can be monitored with [Tensorboard](https://www.tensorflow.org/tensorboard):\n![Screenshot of the Tensorboard UI showing validation images.](docs/images/tensorboard_val_images_screenshot.png)\n_Validation images in the Tensorboard UI._\n\nAll trained models are compatible with InvokeAI:\n\n![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](docs/images/invokeai_yoda_pokemon_lora.png)\n_Example image generated with the prompt \"A 
cute yoda pokemon creature.\" and a trained Pokemon LoRA._\n\n## Contributing\n\nContributors are welcome. For developer guidance, see the [Contributing](https://invoke-ai.github.io/invoke-training/contributing/development_environment/) section of the documentation.\n"
  },
  {
    "path": "docs/contributing/development_environment.md",
    "content": "# Development Environment Setup\n\nSee the [developer installation instructions](../get-started/installation.md#developer-installation).\n"
  },
  {
    "path": "docs/contributing/directory_structure.md",
    "content": "# Directory Structure\n\n```bash\ninvoke-training/\n├── README.md\n├── docs/\n├── src/\n│   └── invoke-training/\n│       ├── _shared/ # Utilities shared across multiple pipelines. Hight unit test coverage.\n│       ├── config/ # Config structures shared by multiple pipelines.\n│       ├── pipelines/ # Each pipeline is isolated in it's own directory with a train.py and config.py.\n│       │   ├── stable_diffusion/\n│       │   │   ├── lora/\n│       │   │   │   ├── config.py\n│       │   │   │   └── train.py\n│       │   │   └── textual_inversion/\n│       │   │       └── ...\n│       │   ├── stable_diffusion_xl/\n│       │   └── ...\n│       └── scripts/ # Main entrypoints.\n└── tests/ # Mirrors src/ directory.\n```\n"
  },
  {
    "path": "docs/contributing/documentation.md",
    "content": "# Documentation\n\nThe documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](https://mkdocstrings.github.io/python/).\n\nTo view your documentation changes locally, run `mkdocs serve`.\n"
  },
  {
    "path": "docs/contributing/tests.md",
    "content": "# Tests\n\nRun all unit tests with:\n\n```bash\npytest tests/\n```\n\nThere are some test 'markers' defined in [pyproject.toml](https://github.com/invoke-ai/invoke-training/blob/main/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights:\n\n```bash\npytest tests/ -m \"not cuda and not loads_model\"\n```\n"
  },
  {
    "path": "docs/get-started/installation.md",
    "content": "# Installation\n\n## Requirements\n\n1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by running `python -V`.\n2. An NVIDIA GPU with >= 8 GB VRAM is recommended for model training.\n\n## Basic Installation\n\n0. Open your terminal and navigate to the directory where you want to clone the `invoke-training` repo.\n1. Clone the repo:\n\n   ```bash\n   git clone https://github.com/invoke-ai/invoke-training.git\n   ```\n\n2. Create and activate a python [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments). This creates an isolated environment for `invoke-training` and its dependencies that won't interfere with other python environments on your system, including any installations of [InvokeAI](https://www.github.com/invoke-ai/invokeai).\n\n   ```bash\n   # Navigate to the invoke-training directory.\n   cd invoke-training\n\n   # Create a new virtual environment named `invoketraining`.\n   python -m venv invoketraining\n\n   # Activate the new virtual environment.\n   # On Windows:\n   .\\invoketraining\\Scripts\\activate\n   # On MacOS / Linux:\n   source invoketraining/bin/activate\n   ```\n\n3. Install `invoke-training` and its dependencies. 
Run the appropriate install command for your system.\n\n   ```bash\n   # A recent version of pip is required, so first upgrade pip:\n   python -m pip install --upgrade pip\n\n   # Install - Windows or Linux with a Nvidia GPU:\n   pip install \".[test]\" --extra-index-url https://download.pytorch.org/whl/cu126\n\n   # Install - Linux with no GPU:\n   pip install \".[test]\" --extra-index-url https://download.pytorch.org/whl/cpu\n\n   # Install - All other systems:\n   pip install \".[test]\"\n   ```\n\nIn the future, before you run `invoke-training`, you must activate the virtual environment you created during installation, using the same command you used during installation.\n\n## Developer Installation\n\nConsider forking the repo if you plan to contribute code changes.\n\nFollow the above installation instructions, cloning your fork instead of this repo if you made a fork.\n\nNext, we suggest setting up the repo's pre-commit hooks to automatically format and lint your contributions:\n\n1. (_Optional_) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (ruff) on `git commit`.\n2. (_Optional_) Setup `ruff` in your IDE of choice.\n"
  },
  {
    "path": "docs/get-started/quick-start.md",
    "content": "# Quick Start\n\n`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started with both options can be found on this page.\n\nThere is also a video introduction to `invoke-training`:\n\n<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/OZIz2vvtlM4?si=iR73F0IhlsolyYAl\" title=\"YouTube video player\" frameborder=\"0\" allow=\"accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share\" referrerpolicy=\"strict-origin-when-cross-origin\" allowfullscreen></iframe>\n\n## Quick Start - GUI\n\n### 1. Installation\n\nFollow the [`invoke-training` installation instructions](./installation.md).\n\n### 2. Launch the GUI\n\nActivate the virtual environment you created during installation, using the same command you used during installation.\n\nYou'll need to do this every time you run `invoke-training`.\n\n```bash\n# From the invoke-training directory:\ninvoke-train-ui\n\n# Or, you can optionally override the default host and port:\ninvoke-train-ui --host 0.0.0.0 --port 1234\n```\n\nAccess the GUI in your browser at the URL printed to the console.\n\n### 3. Configure the training job\n\nSelect the desired training pipeline type in the top-level tab.\n\nFor this tutorial, we don't need to change any of the configuration values. The preset configuration should work well.\n\n### 4. Generate the YAML configuration\n\nClick on 'Generate Config' to generate a YAML configuration file. This YAML configuration file could be used to launch the training job from the CLI, if desired.\n\n### 5. Start training\n\nClick on the 'Start Training' and check your terminal for progress logs.\n\n### 6. Monitor training\n\nMonitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. 
Here you can see generated validation images throughout the training process.\n\n![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png)\n_Validation images in the Tensorboard UI._\n\n### 7. InvokeAI\n\nSelect a checkpoint based on the quality of the generated images.\n\nIf you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.\n\nCopy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:\n\n```bash\n# Note: You will have to replace the timestamp in the checkpoint path.\ncp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors\n```\n\nYou can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉\n\n![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](../images/invokeai_yoda_pokemon_lora.png)\n_Example image generated with the prompt \"A cute yoda pokemon creature.\" and Pokemon LoRA._\n\n## Quick Start - CLI\n\n### 1. Installation\n\nFollow the [`invoke-training` installation instructions](./installation.md).\n\n### 2. Training\n\nActivate the virtual environment you created during installation, using the same command you used during installation.\n\nYou'll need to do this every time you run `invoke-training`.\n\nSee the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.\n"
  },
  {
    "path": "docs/guides/dataset_formats.md",
    "content": "# Dataset Formats\n\n`invoke-training` supports the following dataset formats:\n\n- `IMAGE_CAPTION_JSONL_DATASET`: A local image-caption dataset described by a single `.jsonl` file.\n- `IMAGE_CAPTION_DIR_DATASET`: A local directory of images with associated `.txt` caption files.\n- `IMAGE_DIR_DATASET`: A local directory of images (without captions).\n- `HF_HUB_IMAGE_CAPTION_DATASET`: A Hugging Face Hub dataset containing images and captions.\n\nSee the documentation for a particular training pipeline to see which dataset formats it supports.\n\nThe following sections explain each of these formats in more detail.\n\n## `IMAGE_CAPTION_JSONL_DATASET`\n\nConfig documentation: [ImageCaptionJsonlDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionJsonlDatasetConfig]\n\nA `IMAGE_CAPTION_JSONL_DATASET` consists of a single `.jsonl` file containing image paths and associated captions.\n\nSample directory structure:\n```bash\nmy_custom_dataset/\n├── data.jsonl\n└── train/\n    ├── 0001.png\n    ├── 0002.png\n    ├── 0003.png\n    └── ...\n```\n\nThe contents of `data.jsonl` would be:\n```json\n{\"file_name\": \"train/0001.png\", \"text\": \"This is a caption describing image 0001.\"}\n{\"file_name\": \"train/0002.png\", \"text\": \"This is a caption describing image 0002.\"}\n{\"file_name\": \"train/0003.png\", \"text\": \"This is a caption describing image 0003.\"}\n```\n\nThe image file paths can be either absolute paths, or relative to the `.jsonl` file.\n\nFinally, this dataset can be used with the following pipeline dataset configuration:\n```yaml\ntype: IMAGE_CAPTION_JSONL_DATASET\njsonl_path: /path/to/my_custom_dataset/metadata.jsonl\nimage_column: file_name\ncaption_column: text\n```\n\nA useful characteristic of this dataset format is that a `.jsonl` file can reference an image file anywhere on the local disk. 
It is common to maintain multiple `.jsonl` datasets that reference some of the same images without needing multiple copies of those images on disk.\n\n## `IMAGE_CAPTION_DIR_DATASET`\n\nConfig documentation: [ImageCaptionDirDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionDirDatasetConfig]\n\nAn `IMAGE_CAPTION_DIR_DATASET` consists of a directory of image files and corresponding `.txt` caption files of the same name.\n\nSample directory structure:\n```bash\nmy_custom_dataset/\n├── 0001.png\n├── 0001.txt\n├── 0002.jpg\n├── 0002.txt\n├── 0003.png\n├── 0003.txt\n└── ...\n```\n\nEach `.txt` file should contain a caption on the first line of the file. Here are the sample contents of `0001.txt`:\n```txt title=\"0001.txt\"\nthis is a caption for example 0001\n```\n\nThis dataset can be used with the following pipeline dataset configuration:\n```yaml\ntype: IMAGE_CAPTION_DIR_DATASET\ndataset_dir: /path/to/my_custom_dataset\n```\n\n## `IMAGE_DIR_DATASET`\n\nConfig documentation: [ImageDirDatasetConfig][invoke_training.config.data.dataset_config.ImageDirDatasetConfig]\n\nAn `IMAGE_DIR_DATASET` consists of a single directory of images (without captions).\n\nSample directory structure:\n```bash\nmy_custom_dataset/\n├── 0001.png\n├── 0002.jpg\n├── 0003.png\n└── ...\n```\n\nThis dataset can be used with the following pipeline dataset configuration:\n```yaml\ntype: IMAGE_DIR_DATASET\ndataset_dir: /path/to/my_custom_dataset\n```\n\n## `HF_HUB_IMAGE_CAPTION_DATASET`\n\nConfig documentation: [HFHubImageCaptionDatasetConfig][invoke_training.config.data.dataset_config.HFHubImageCaptionDatasetConfig]\n\nThe `HF_HUB_IMAGE_CAPTION_DATASET` dataset format can be used to access publicly available datasets on the [Hugging Face Hub](https://huggingface.co/datasets). You can filter for the `Text-to-Image` task to find relevant datasets that contain both an image column and a caption column. 
[lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) is a popular choice if you're not sure where to start.\n"
  },
  {
    "path": "docs/guides/model_merge.md",
    "content": "# Model Merging\n\n`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.\n\n## `extract_lora_from_model_diff.py`\n\nExtract a LoRA model that represents the difference between two base models.\n\nNote that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.\n\nFor usage docs, run:\n```bash\npython src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h\n```\n\n## `merge_lora_into_model.py`\n\nMerge a LoRA model into a base model to produce a new base model.\n\nFor usage docs, run:\n```bash\npython src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h\n```\n\n## `merge_models.py`\n\nMerge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.\n\nFor usage docs, run:\n```bash\npython src/invoke_training/model_merge/scripts/merge_models.py -h\n```\n\n## `merge_task_models_to_base_model.py`\n\nMerge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.\n\nIf you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy.\n\nFor usage docs, run:\n```bash\npython src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h\n```\n"
  },
  {
    "path": "docs/guides/stable_diffusion/dpo_lora_sd.md",
    "content": "# (Experimental) Diffusion DPO - SD\n\n!!! tip \"Experimental\"\n    The Diffusion Direct Preference Optimization training pipeline is still experimental. Support may be dropped at any time.\n\nThis tutorial walks through some initial experiments around using Diffusion Direct Preference Optimization (DPO) ([paper](https://arxiv.org/abs/2311.12908)) to train Stable Diffusion LoRA models.\n\n\n## Experiment 1: `pickapic_v2` LoRA Training\n\nThe Diffusion-DPO paper does full model fine-tuning on the [pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset, which consists of roughly 1M AI-generated image pairs with preference annotations. In this experiment, we attempt to fine-tune a Stable Diffusion LoRA model using a small subset of the pickapic_v2 dataset.\n\nRun this experiment with the following command:\n```bash\ninvoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml\n```\n\nHere is a cherry-picked example of a prompt for which this training process was clearly beneficial.\nPrompt: \"*A galaxy-colored figurine is floating over the sea at sunset, photorealistic*\"\n\n| Before DPO Training | After DPO Training (same seed)|\n| - | - |\n| ![Sample image before DPO training.](../../images/dpo/before_dpo.jpg) | ![Sample image after DPO training.](../../images/dpo/after_dpo.jpg) |\n\n## Experiment 2: LoRA Model Refinement\n\nAs a second experiment, we attempt the following workflow:\n\n1. Train a Stable Diffusion LoRA model on a particular style.\n2. Generate pairs of images of the character with the trained LoRA model.\n3. Annotate the preferred image from each pair.\n4. Apply Diffusion-DPO to the preference-annotated pairs to further fine-tune the LoRA model.\n\nNote: The steps listed below are pretty rough. They are included primarily for reference for someone looking to resume this line of work in the future.\n\n### 1. 
Train a style LoRA\n\n```bash\ninvoke-train -c src/invoke_training/sample_configs/sd_lora_pokemon_1x8gb.yaml\n```\n\n### 2. Generate images\n\nPrepare ~100 relevant prompts that will be used to generate training data with the freshly-trained LoRA model. Add the prompts to a `.txt` file - one prompt per line.\n\nExample prompts:\n```txt\na cute orange pokemon character with pointy ears\na drawing of a purple fish\na cartoon blob with a smile on its face\na drawing of a snail with big eyes\n...\n```\n\n```bash\n# Convert the LoRA checkpoint of interest to Kohya format.\n# You will have to change the path timestamps in this example command.\n# TODO(ryand): This manual conversion shouldn't be necessary.\npython src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \\\n  --src-ckpt-dir output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003/ \\\n  --dst-ckpt-file output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors\n\n# Generate 2 pairs of images for each prompt.\ninvoke-generate-images \\\n  -o output/pokemon_pairs \\\n  -m runwayml/stable-diffusion-v1-5 \\\n  -v fp16 \\\n  -l output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors \\\n  --sd-version SD \\\n  --prompt-file path/to/prompts.txt \\\n  --set-size 2 \\\n  --num-sets 2 \\\n  --height 512 \\\n  --width 512\n```\n\n### 3. Annotate the image pair preferences\n\nLaunch the gradio UI for selecting image pair preferences.\n\n```bash\n# Note: rank_images.py accepts a full training pipeline config, but only uses the dataset configuration.\npython src/invoke_training/scripts/_experimental/rank_images.py -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml\n```\n\nAfter completing the pair annotations, click \"Save Metadata\" and move the resultant metadata file to your image data directory (e.g. `output/pokemon_pairs/metadata.jsonl`).\n\n### 4. 
Run Diffusion-DPO\n\n```bash\ninvoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml\n```"
  },
  {
    "path": "docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md",
    "content": "# LoRA with Masks - SDXL\n\nThis tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model.\n\nMasks can be used to weight regions of images in a dataset to control how much they contribute to the training process. In this tutorial we will use masks to train on a small dataset of images of Bruce the Gnome (4 images). With such a small dataset, there is a high risk of overfitting to the background elements from the images. We will use masks to avoid this problem and focus only on the object of interest.\n\n## 1 - Dataset Preparation\n\nFor this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:\n\n| | |\n| - | - |\n| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) |\n| ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) |\n\nThis sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).\n\n## 2 - Generate Masks\n\nUse the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. 
In this case we are using the prompt `\"a stuffed gnome\"`:\n```bash\npython src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \\\n  --in-jsonl sample_data/bruce_the_gnome/data.jsonl \\\n  --out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \\\n  --prompt \"a stuffed gnome\"\n```\n\nThe mask generation script will produce the following outputs:\n\n- A directory of generated masks: `sample_data/bruce_the_gnome/masks/`\n- A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl`\n\n## 3 - Review the Generated Masks\n\nReview the generated masks to make sure that the target regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result. Alternatively, you can edit the masks manually. The masks are simply single-channel grayscale images (0=background, 255=foreground).\n\nHere are some examples of the masks that we just generated:\n\n| | |\n| - | - |\n| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 1 mask.](../../images/bruce_masks/001_mask.png) |\n| ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | ![bruce_the_gnome dataset image 2 mask.](../../images/bruce_masks/002_mask.png) |\n\n## 4 - Configuration\n\nBelow is the training configuration that we'll use for this tutorial.\n\nRaw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml).\n\n\n```yaml title=\"sdxl_lora_masks_gnome_1x24gb.yaml\"\n--8<-- \"src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml\"\n```\n\nFull documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md)\n\nThere are a few things to note about this training config:\n\n- We set `use_masks: True` in order 
to use the masks that we generated. This configuration is only compatible with datasets that have mask data.\n- The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking with the small dataset size causes training to progress very quickly. These configuration fields were all adjusted accordingly to avoid overfitting.\n\n## 5 - Start Training\n\nLaunch the training run.\n```bash\n# From inside the invoke-training/ source directory:\ninvoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml\n```\n\nTraining takes ~30 mins on an NVIDIA RTX 4090.\n\n## 6 - Monitor\n\nIn a new terminal, launch Tensorboard to monitor the training run:\n```bash\ntensorboard --logdir output/\n```\nAccess Tensorboard at [localhost:6006](http://localhost:6006) in your browser.\n\nSample images will be logged to Tensorboard so that you can see how the model is evolving.\n\nOnce training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300:\n\n![Screenshot of the Tensorboard UI showing the validation images for step 300.](../../images/bruce_masks/bruce_masks_step_300.jpg)\n*Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: \"A stuffed gnome at the beach with a pina colada in its hand.\".*\n\n\n## 7 - Import into InvokeAI\n\nIf you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.\n\nImport your trained LoRA model from the 'Models' tab.\n\nCongratulations, you can now use your new Bruce-the-Gnome model! 🎉\n"
  },
  {
    "path": "docs/guides/stable_diffusion/robocats_finetune_sdxl.md",
    "content": "# Finetune - SDXL\n\nThis tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.\n\n## 0 - Prerequisites\n\nFull model finetuning is more compute-intensive than parameter-efficient finetuning alternatives (e.g. LoRA or Textual Inversion). This tutorial requires a minimum of 24GB of GPU VRAM.\n\n## 1 - Dataset Preparation\n\nFor this tutorial, we will use a dataset consisting of 14 images of robocats. The images were auto-captioned. Here are some sample images from the dataset, including their captions:\n\n| | |\n| - | - |\n| ![A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.](../../images/robocats/sipu3h70yb87rju8a8l36ejr.jpg) | ![A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.](../../images/robocats/v2h3ld50bi9owhhzo9gf9utg.jpg) |\n| *A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.* | *A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.* |\n\n## 2 - Configuration\n\nBelow is the training configuration that we'll use for this tutorial.\n\nRaw config file: [src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml).\n\n\n```yaml title=\"sdxl_finetune_robocats_1x24gb.yaml\"\n--8<-- \"src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml\"\n```\n\nFull documentation of all of the configuration options is here: [Finetune SDXL Config](../../reference/config/pipelines/sdxl_finetune.md)\n\n!!! 
note \"`save_checkpoint_format`\"\n    Note the `save_checkpoint_format` setting, as it is unique to full finetune training. For this tutorial, we have set `save_checkpoint_format: trained_only_diffusers`. This means that only the UNet model will be saved at each checkpoint, and it will be saved in diffusers format. This setting conserves disk space by not redundantly saving the non-trained weights. Before these UNet checkpoints can be used, they must either be merged into a full model, or extracted into a LoRA. Instructions for this follow later in this tutorial. A full explanation of the `save_checkpoint_format` options can be found here:  [save_checkpoint_format][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.save_checkpoint_format].\n\n\n## 3 - Start Training\n\nLaunch the training run.\n```bash\n# From inside the invoke-training/ source directory:\ninvoke-train -c src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml\n```\n\nTraining takes ~45 mins on an NVIDIA RTX 4090.\n\n## 4 - Monitor\n\nIn a new terminal, launch Tensorboard to monitor the training run:\n```bash\ntensorboard --logdir output/\n```\nAccess Tensorboard at [localhost:6006](http://localhost:6006) in your browser.\n\nSample images will be logged to Tensorboard so that you can see how the model is evolving.\n\nOnce training is complete, select the model checkpoint that produces the best visual results.\n\n## 5 - Prepare the trained model\n\nSince we set `save_checkpoint_format: trained_only_diffusers`, our selected checkpoint only contains the UNet model weights. 
The checkpoint has the following directory structure:\n\n```bash\noutput/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/\n└── unet\n    ├── config.json\n    └── diffusion_pytorch_model.safetensors\n```\n\nBefore we can use this trained model, we must do one of the following:\n\n- Prepare a full diffusers checkpoint with the new UNet weights.\n- Extract the difference between the trained UNet and the original UNet into a LoRA model.\n\n### Prepare a full model\n\nIf we want to use our finetuned UNet model, we must first package it into a format supported by applications like InvokeAI.\n\nIn this section we will assume that we have a full SDXL base model in diffusers format. It should have a directory structure like the one shown before. We simply need to replace the `unet/` directory with the one from our selected training checkpoint:\n```bash\nstable-diffusion-xl-base-1.0\n├── model_index.json\n├── scheduler\n│   └── scheduler_config.json\n├── text_encoder\n│   ├── config.json\n│   └── model.fp16.safetensors\n├── text_encoder_2\n│   ├── config.json\n│   └── model.fp16.safetensors\n├── tokenizer\n│   ├── merges.txt\n│   ├── special_tokens_map.json\n│   ├── tokenizer_config.json\n│   └── vocab.json\n├── tokenizer_2\n│   ├── merges.txt\n│   ├── special_tokens_map.json\n│   ├── tokenizer_config.json\n│   └── vocab.json\n├── unet # <-- Replace this directory with the trained checkpoint.\n│   ├── config.json\n│   └── diffusion_pytorch_model.fp16.safetensors\n├── vae\n│   ├── config.json\n│   └── diffusion_pytorch_model.fp16.safetensors\n└── vae_1_0\n    └── diffusion_pytorch_model.fp16.safetensors\n```\n\n!!! note \"diffusers variants (e.g. 'fp16')\"\n    In this example, notice that the `*.safetensors` files contain `.fp16.` in their filenames. Hugging Face refers to this identifier as a \"variant\". 
It is used to select between multiple model variants in their model hub.\n    In this case, we should add the `.fp16.` variant tag to our finetuned UNet for consistency with the rest of the model. Since we set `save_dtype: float16` in our training config, the `fp16` tag accurately represents the precision of our UNet model file.\n\n### Extract a LoRA model\n\nAn alternative to using the finetuned UNet model directly is to compare it against the original and extract the difference as a LoRA model. The resultant LoRA has a much smaller file size and can be applied to any base model. But, the LoRA model is a *lossy* representation of the difference, so some quality degradation is expected.\n\nTo extract a LoRA model, run the following command:\n```bash\npython src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \\\n  --model-type SDXL \\\n  --model-orig path/to/stable-diffusion-xl-base-1.0 \\\n  --model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \\\n  --save-to robocats_lora_step_2000.safetensors \\\n  --lora-rank 32\n```\n\n## 6 - Import into InvokeAI\n\nIf you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.\n\nImport your finetuned diffusers model or your extracted LoRA from the 'Models' tab.\n\nCongratulations, you can now use your new robocat model! 🎉\n\n## 7 - Comparison: Finetune vs. LoRA Extraction\n\nAs noted earlier, the LoRA extraction process is lossy for a number of reasons.\n\nBelow, we compare images generated with the same seed and prompt for 3 different model configurations.\n\nPrompt: *In robocat style, a robotic lion in the jungle.*\n\n| SDXL Base 1.0 | w/ Finetuned UNet | w/ Extracted LoRA |\n| - | - | - |\n| ![Image generated with SDXL Base 1.0. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_base.jpg) | ![Image generated with finetuned UNet. 
Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_finetuned.jpg) | ![Image generated with extracted LoRA. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_extracted_lora.jpg) |\n"
  },
  {
    "path": "docs/guides/stable_diffusion/textual_inversion_sdxl.md",
    "content": "# Textual Inversion - SDXL\n\nThis tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training run with a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.\n\n## 1 - Dataset\n\nFor this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:\n\n| | |\n| - | - |\n| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) |\n| ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) |\n\nThis sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).\n\nHere are a few tips for preparing a Textual Inversion dataset:\n\n- Aim for 4 to 50 images of your concept (object / style). The optimal number depends on many factors, and can be much higher than this for some use cases.\n- Vary all of the image features that you *don't* want your TI embedding to contain (e.g. 
background, pose, lighting, etc.).\n\n## 2 - Configuration\n\nBelow is the training configuration that we'll use for this tutorial.\n\nRaw config file: [src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml).\n\nFull config reference docs: [Textual Inversion SDXL Config](../../reference/config/pipelines/sdxl_textual_inversion.md)\n\n```yaml title=\"sdxl_textual_inversion_gnome_1x24gb.yaml\"\n--8<-- \"src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml\"\n```\n\n## 3 - Start Training\n\n[Install invoke-training](../../get-started/installation.md), if you haven't already.\n\nLaunch the Textual Inversion training pipeline:\n```bash\n# From inside the invoke-training/ source directory:\ninvoke-train -c src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml\n```\n\nTraining takes ~40 mins on an NVIDIA RTX 4090.\n\n## 4 - Monitor\n\nIn a new terminal, launch Tensorboard to monitor the training run:\n```bash\ntensorboard --logdir output/\n```\nAccess Tensorboard at [localhost:6006](http://localhost:6006) in your browser.\n\nSample images will be logged to Tensorboard so that you can see how the Textual Inversion embedding is evolving.\n\nOnce training is complete, select the epoch that produces the best visual results.\n\nFor this tutorial, we'll choose epoch 500:\n![Screenshot of the Tensorboard UI showing the validation images for epoch 500.](../../images/tensorboard_bruce_the_gnome_epoch_500.png)\n*Screenshot of the Tensorboard UI showing the validation images for epoch 500.*\n\n## 5 - Transfer to InvokeAI\n\nIf you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.\n\nCopy the selected TI embedding into your `${INVOKEAI_ROOT}/autoimport/embedding/` directory. 
For example:\n```bash\ncp output/sdxl_ti_bruce_the_gnome/1702587511.2273068/checkpoint_epoch-00000500.safetensors ${INVOKEAI_ROOT}/autoimport/embedding/bruce_the_gnome.safetensors\n```\n\nNote that we renamed the file to `bruce_the_gnome.safetensors`. You can choose any file name, but this will become the token used to reference your embedding. So, in our case, we can refer to our new embedding by including `<bruce_the_gnome>` in our prompts.\n\nLaunch InvokeAI and you can now use your new `bruce_the_gnome` TI embedding! 🎉\n\n![Screenshot of the InvokeAI UI with an example of an image generated with the bruce_the_gnome TI embedding.](../../images/invokeai_bruce_the_gnome_ti.png)\n*Example image generated with the prompt \"`a photo of <bruce_the_gnome> at the park`\".*\n"
  },
  {
    "path": "docs/index.md",
    "content": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).\n\n## Documentation\n\nThe documentation is organized as follows:\n\n- [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline.\n- [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines.\n- [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options.\n- [Contributing](contributing/development_environment.md): Information for `invoke-training` developers.\n"
  },
  {
    "path": "docs/reference/config/index.md",
    "content": "# Config Reference\n\nThis section contains reference documentation for the `invoke-training` configuration schema (i.e. documentation for all of the supported training options).\n\nThis documentation uses python typing semantics to define the configuration schema. Typically the configuration for a training run is specified in a YAML file and then parsed against this schema.\n"
  },
  {
    "path": "docs/reference/config/pipelines/sd_lora.md",
    "content": "# `SdLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\""
  },
  {
    "path": "docs/reference/config/pipelines/sd_textual_inversion.md",
    "content": "# `SdTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\"\n"
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_finetune.md",
    "content": "# `SdxlFinetuneConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\"\n"
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_lora.md",
    "content": "# `SdxlLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\""
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md",
    "content": "# `SdxlLoraAndTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\""
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_textual_inversion.md",
    "content": "# `SdxlTextualInversionConfig`\n\nBelow is a sample yaml config file for Textual Inversion SDXL training ([raw file](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml)). All of the configuration fields are explained in detail on this page.\n\n```yaml title=\"sdxl_textual_inversion_gnome_1x24gb.yaml\"\n--8<-- \"src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml\"\n```\n\n<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->\n::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig\n    options:\n      members:\n      - type\n\n<!-- Note that we always hide \"model_config\", as it should not be set by the user. -->\n::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig\n    options:\n      filters:\n      - \"!^model_config\"\n      - \"!^type\"\n"
  },
  {
    "path": "docs/reference/config/shared/data/data_loader_config.md",
    "content": "::: invoke_training.config.data.data_loader_config\n    options:\n      filters:\n      - \"!^model_config\"\n"
  },
  {
    "path": "docs/reference/config/shared/data/dataset_config.md",
    "content": "::: invoke_training.config.data.dataset_config\n    options:\n      filters:\n      - \"!^model_config\""
  },
  {
    "path": "docs/reference/config/shared/optimizer_config.md",
    "content": "::: invoke_training.config.optimizer.optimizer_config\n    options:\n      filters:\n      - \"!^model_config\"\n"
  },
  {
    "path": "docs/templates/python/material/labels.html",
    "content": "<!--\n    This file is intentionally empty. It overrides the default contents of\n    https://github.com/mkdocstrings/python/blob/master/src/mkdocstrings_handlers/python/templates/material/labels.html\n    to hide labels (class-attribute, instance-attribute, etc.)\n-->\n"
  },
  {
    "path": "mkdocs.yml",
    "content": "site_name: invoke-training\nsite_url: https://invoke-ai.github.io/invoke-training/\n\nrepo_name: invoke-ai/invoke-training\nrepo_url: https://github.com/invoke-ai/invoke-training\n\ntheme:\n  name: material\n  features:\n    - navigation.tabs\n    - navigation.indexes\n    - navigation.sections\n    - content.code.copy\n\nmarkdown_extensions:\n  - admonition\n  - sane_lists\n  - pymdownx.highlight:\n      anchor_linenums: true\n      line_spans: __span\n      pygments_lang_class: true\n  - pymdownx.inlinehilite\n  - pymdownx.snippets\n  - pymdownx.superfences\n\nnav:\n  - Welcome: index.md\n  - Get Started:\n      - get-started/installation.md\n      - get-started/quick-start.md\n  - Guides:\n      - Dataset Formats: guides/dataset_formats.md\n      - Model Merging: guides/model_merge.md\n      - Stable Diffusion Training:\n          - guides/stable_diffusion/robocats_finetune_sdxl.md\n          - guides/stable_diffusion/gnome_lora_masks_sdxl.md\n          - guides/stable_diffusion/textual_inversion_sdxl.md\n          - guides/stable_diffusion/dpo_lora_sd.md\n  - YAML Config Reference:\n      - reference/config/index.md\n      - pipelines:\n          - SD LoRA Config: reference/config/pipelines/sd_lora.md\n          - SD Textual Inversion Config: reference/config/pipelines/sd_textual_inversion.md\n          - SDXL LoRA Config: reference/config/pipelines/sdxl_lora.md\n          - SDXL Textual Inversion Config: reference/config/pipelines/sdxl_textual_inversion.md\n          - SDXL LoRA and Textual Inversion Config: reference/config/pipelines/sdxl_lora_and_textual_inversion.md\n          - SDXL Finetune Config: reference/config/pipelines/sdxl_finetune.md\n      - shared:\n          - data_loader_config: reference/config/shared/data/data_loader_config.md\n          - dataset_config: reference/config/shared/data/dataset_config.md\n          - optimizer_config: reference/config/shared/optimizer_config.md\n  - Contributing:\n      - 
contributing/development_environment.md\n      - contributing/directory_structure.md\n      - contributing/tests.md\n      - contributing/documentation.md\n\nplugins:\n  - search\n  - mkdocstrings:\n      default_handler: python\n      custom_templates: docs/templates\n      handlers:\n        python:\n          options:\n            show_root_heading: false\n            show_root_toc_entry: false\n            show_bases: false\n            show_source: false\n            show_if_no_docstring: true\n            inherited_members: true\n            annotations_path: brief\n            separate_signature: true\n            show_signature_annotations: true\n            members_order: source\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools>=65.5\", \"pip>=22.3\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"invoke-training\"\nversion = \"0.0.1\"\nauthors = [{ name = \"The Invoke AI Team\", email = \"ryan@invoke.ai\" }]\ndescription = \"A library for Stable Diffusion model training.\"\nreadme = \"README.md\"\nrequires-python = \">=3.10\"\nlicense = { text = \"Apache-2.0\" }\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"Operating System :: OS Independent\",\n]\ndependencies = [\n    \"accelerate\",\n    \"datasets~=2.14.3\",\n    \"diffusers[torch]\",\n    \"einops\",\n    \"fastapi\",\n    \"gradio\",\n    \"invokeai>=5.10.0a1\",\n    \"numpy<2.0.0\",\n    \"omegaconf\",\n    \"peft~=0.11.1\",\n    \"pillow\",\n    \"prodigyopt\",\n    \"pydantic\",\n    \"pyyaml\",\n    \"safetensors\",\n    \"tensorboard\",\n    \"torch\",\n    \"torchvision\",\n    \"tqdm\",\n    \"transformers\",\n    \"uvicorn[standard]\",\n]\n\n[project.optional-dependencies]\n\"xformers\" = [\"xformers>=0.0.28.post1; sys_platform!='darwin'\"]\n\"bitsandbytes\" = [\"bitsandbytes>=0.43.1; sys_platform!='darwin'\"]\n\n\"test\" = [\n    \"mkdocs\",\n    \"mkdocs-material\",\n    \"mkdocstrings[python]\",\n    \"pre-commit~=3.3.3\",\n    \"pytest~=7.4.0\",\n    \"ruff~=0.11.2\",\n    \"ruff-lsp\",\n]\n\n[project.scripts]\n\"invoke-train\" = \"invoke_training.scripts.invoke_train:main\"\n\"invoke-train-ui\" = \"invoke_training.scripts.invoke_train_ui:main\"\n\"invoke-generate-images\" = \"invoke_training.scripts.invoke_generate_images:main\"\n\"invoke-visualize-data-loading\" = \"invoke_training.scripts.invoke_visualize_data_loading:main\"\n\n[project.urls]\n\"Homepage\" = \"https://github.com/invoke-ai/invoke-training\"\n\"Discord\" = \"https://discord.gg/ZmtBAhwWhy\"\n\n[tool.setuptools.package-data]\n\"invoke_training.assets\" = [\"*.png\"]\n\"invoke_training.sample_configs\" = [\"**/*.yaml\"]\n\"invoke_training.ui\" = 
[\"*.html\"]\n\n[tool.ruff]\nsrc = [\"src\"]\nlint.select = [\"E\", \"F\", \"W\", \"C9\", \"N8\", \"I\"]\ntarget-version = \"py39\"\nline-length = 120\n\n[tool.pytest.ini_options]\naddopts = \"--strict-markers\"\nmarkers = [\n    \"cuda: marks tests that require a CUDA GPU\",\n    \"loads_model: marks tests that require a model (or data) from the HF hub\",\n]\n"
  },
  {
    "path": "sample_data/bruce_the_gnome/data.jsonl",
    "content": "{\"image\": \"001.png\", \"text\": \"A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background.\"}\n{\"image\": \"002.png\", \"text\": \"A stuffed gnome stands on a black tiled floor, with a silver refrigerator and white wall in the background.\"}\n{\"image\": \"004.png\", \"text\": \"A stuffed gnome sits on a white marble floor, photorealistic.\"}\n{\"image\": \"003.png\", \"text\": \"A stuffed gnome sits on a gray tiled floor, facing the camera.\"}\n"
  },
  {
    "path": "src/invoke_training/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/accelerator/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/accelerator/accelerator_utils.py",
    "content": "import logging\nimport os\nfrom typing import Literal\n\nimport datasets\nimport diffusers\nimport torch\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import MultiProcessAdapter, get_logger\nfrom accelerate.utils import ProjectConfiguration\n\n\ndef initialize_accelerator(\n    out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str\n) -> Accelerator:\n    \"\"\"Configure Hugging Face accelerate and return an Accelerator.\n\n    Args:\n        out_dir (str): The output directory where results will be written.\n        gradient_accumulation_steps (int): Forwarded to accelerate.Accelerator(...).\n        mixed_precision (str): Forwarded to accelerate.Accelerator(...).\n        log_with (str): Forwarded to accelerate.Accelerator(...).\n\n    Returns:\n        Accelerator\n    \"\"\"\n    accelerator_project_config = ProjectConfiguration(\n        project_dir=out_dir,\n        logging_dir=os.path.join(out_dir, \"logs\"),\n    )\n    return Accelerator(\n        project_config=accelerator_project_config,\n        gradient_accumulation_steps=gradient_accumulation_steps,\n        mixed_precision=mixed_precision,\n        log_with=log_with,\n    )\n\n\ndef initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:\n    \"\"\"Configure logging.\n\n    Returns an accelerate logger with multi-process logging support. Logging is configured to be more verbose on the\n    main process. 
Non-main processes only log at error level for Hugging Face libraries (datasets, transformers,\n    diffusers).\n\n    Args:\n        logger_name (str): The name of the logger to return.\n        accelerator (Accelerator): The Accelerator used to check whether this is the local main process.\n\n    Returns:\n        MultiProcessAdapter: The logger named `logger_name`, with multi-process logging support.\n    \"\"\"\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    if accelerator.is_local_main_process:\n        datasets.utils.logging.set_verbosity_warning()\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        # Only log errors from non-main processes.\n        datasets.utils.logging.set_verbosity_error()\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    return get_logger(logger_name)\n\n\ndef get_mixed_precision_dtype(accelerator: Accelerator):\n    \"\"\"Extract torch.dtype from Accelerator config.\n\n    Args:\n        accelerator (Accelerator): The Hugging Face Accelerator.\n\n    Raises:\n        NotImplementedError: If the accelerator's mixed_precision configuration is not recognized.\n\n    Returns:\n        torch.dtype: The weight type inferred from the accelerator mixed_precision configuration.\n    \"\"\"\n    weight_dtype: torch.dtype = torch.float32\n    if accelerator.mixed_precision is None or accelerator.mixed_precision == \"no\":\n        weight_dtype = torch.float32\n    elif accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n    else:\n        raise NotImplementedError(f\"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.\")\n    return weight_dtype\n\n\ndef get_dtype_from_str(dtype_str: Literal[\"float16\", \"bfloat16\", \"float32\"]) -> torch.dtype:\n    if dtype_str == 
\"float16\":\n        return torch.float16\n    elif dtype_str == \"bfloat16\":\n        return torch.bfloat16\n    elif dtype_str == \"float32\":\n        return torch.float32\n    else:\n        raise ValueError(f\"Unsupported dtype: {dtype_str}\")\n"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/checkpoint_tracker.py",
    "content": "import os\nimport shutil\nimport typing\n\n\nclass CheckpointTracker:\n    \"\"\"A utility class for managing checkpoint paths.\n\n    Manages checkpoint paths of the following forms:\n    - Checkpoint directories: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}`\n    - Checkpoint files: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}{extension}`\n    \"\"\"\n\n    def __init__(\n        self,\n        base_dir: str,\n        prefix: str,\n        extension: typing.Optional[str] = None,\n        max_checkpoints: typing.Optional[int] = None,\n        index_padding: int = 8,\n    ):\n        \"\"\"Initialize a CheckpointTracker.\n\n        Args:\n            base_dir (str): The base checkpoint directory.\n            prefix (str): A prefix applied to every checkpoint.\n            extension (str, optional): If set, this is the file extension that will be applied to all checkpoints\n                (usually one of \".pt\", \".ckpt\", or \".safetensors\"). If None, then it will be assumed that we are\n                managing checkpoint directories rather than files.\n            max_checkpoints (typing.Optional[int], optional): The maximum number of checkpoints that should exist in\n                base_dir.\n            index_padding (int, optional): The length of the zero-padded epoch/step counts in the generated checkpoint\n                names. E.g. 
index_padding=8 would produce checkpoint paths like\n                \"base_dir/prefix-epoch_00000001-step_00000001.ckpt\".\n\n        Raises:\n            ValueError: If extension is provided, but it doesn not start with a '.'.\n        \"\"\"\n        if extension is not None and not extension.startswith(\".\"):\n            raise ValueError(f\"extension='{extension}' must start with a '.'.\")\n\n        self._base_dir = base_dir\n        self._prefix = prefix\n        self._extension = extension\n        self._max_checkpoints = max_checkpoints\n        self._index_padding = index_padding\n\n    def prune(self, buffer_num: int = 1) -> int:\n        \"\"\"Delete checkpoint files and directories so that there are at most `max_checkpoints - buffer_num` checkpoints\n        remaining. The checkpoints with the lowest step counts will be deleted.\n\n        Args:\n            buffer_num (int, optional): The number below `max_checkpoints` to 'free-up'.\n\n        Returns:\n            int: The number of checkpoints deleted.\n        \"\"\"\n        if self._max_checkpoints is None:\n            return 0\n\n        checkpoints = os.listdir(self._base_dir)\n        checkpoints = [p for p in checkpoints if p.startswith(self._prefix)]\n        checkpoints = sorted(\n            checkpoints,\n            key=lambda x: int(os.path.splitext(x)[0].split(\"-step_\")[-1]),\n        )\n\n        num_to_remove = len(checkpoints) - (self._max_checkpoints - buffer_num)\n        if num_to_remove > 0:\n            checkpoints_to_remove = checkpoints[:num_to_remove]\n\n            for checkpoint_to_remove in checkpoints_to_remove:\n                checkpoint_to_remove = os.path.join(self._base_dir, checkpoint_to_remove)\n                if os.path.isfile(checkpoint_to_remove):\n                    # Delete checkpoint file.\n                    os.remove(checkpoint_to_remove)\n                else:\n                    # Delete checkpoint directory.\n                    
shutil.rmtree(checkpoint_to_remove)\n\n        return max(0, num_to_remove)\n\n    def get_path(self, epoch: int, step: int) -> str:\n        \"\"\"Get the checkpoint path for the given `epoch` and `step` counts.\n\n        Args:\n            epoch (int): The number of completed epochs.\n            step (int): The number of completed training steps.\n\n        Returns:\n            str: The checkpoint path.\n        \"\"\"\n        suffix = self._extension or \"\"\n        return os.path.join(\n            self._base_dir,\n            f\"{self._prefix.strip()}-epoch_{epoch:0>{self._index_padding}}-step_{step:0>{self._index_padding}}{suffix}\",\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py",
    "content": "from pathlib import Path\n\nimport peft\nimport torch\n\n\ndef save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):\n    \"\"\"Save a dict of PeftModels to a checkpoint directory.\n\n    The `models` dict keys are used as the subdirectories for each individual model.\n\n    `load_multi_model_peft_checkpoint(...)` can be used to load the resultant checkpoint.\n    \"\"\"\n    checkpoint_dir = Path(checkpoint_dir)\n    for model_key, peft_model in models.items():\n        assert isinstance(peft_model, peft.PeftModel)\n\n        # HACK(ryand): PeftModel.save_pretrained(...) expects the config to have a \"_name_or_path\" entry. For now, we\n        # set this to None here. This should be fixed upstream in PEFT.\n        if (\n            hasattr(peft_model, \"config\")\n            and isinstance(peft_model.config, dict)\n            and \"_name_or_path\" not in peft_model.config\n        ):\n            peft_model.config[\"_name_or_path\"] = None\n\n        peft_model.save_pretrained(str(checkpoint_dir / model_key))\n\n\ndef load_multi_model_peft_checkpoint(\n    checkpoint_dir: Path | str,\n    models: dict[str, torch.nn.Module],\n    is_trainable: bool = False,\n    raise_if_subdir_missing: bool = True,\n) -> dict[str, torch.nn.Module]:\n    \"\"\"Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`.\"\"\"\n    checkpoint_dir = Path(checkpoint_dir)\n    assert checkpoint_dir.exists()\n\n    out_models = {}\n    for model_key, model in models.items():\n        dir_path: Path = checkpoint_dir / model_key\n        if dir_path.exists():\n            out_models[model_key] = peft.PeftModel.from_pretrained(model, dir_path, is_trainable=is_trainable)\n        else:\n            if raise_if_subdir_missing:\n                raise ValueError(f\"'{dir_path}' does not exist.\")\n            else:\n                # Pass through the model unchanged.\n                
out_models[model_key] = model\n\n    return out_models\n\n\n# This implementation is based on\n# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20\ndef _convert_peft_state_dict_to_kohya_state_dict(\n    lora_config: peft.LoraConfig,\n    peft_state_dict: dict[str, torch.Tensor],\n    prefix: str,\n    dtype: torch.dtype,\n) -> dict[str, torch.Tensor]:\n    kohya_ss_state_dict = {}\n    for peft_key, weight in peft_state_dict.items():\n        kohya_key = peft_key.replace(\"base_model.model\", prefix)\n        kohya_key = kohya_key.replace(\"lora_A\", \"lora_down\")\n        kohya_key = kohya_key.replace(\"lora_B\", \"lora_up\")\n        kohya_key = kohya_key.replace(\".\", \"_\", kohya_key.count(\".\") - 2)\n        kohya_ss_state_dict[kohya_key] = weight.to(dtype)\n\n        # Set alpha parameter\n        if \"lora_down\" in kohya_key:\n            alpha_key = f\"{kohya_key.split('.')[0]}.alpha\"\n            kohya_ss_state_dict[alpha_key] = torch.tensor(lora_config.lora_alpha).to(dtype)\n\n    return kohya_ss_state_dict\n\n\ndef _convert_peft_models_to_kohya_state_dict(\n    kohya_prefixes: list[str], models: list[peft.PeftModel]\n) -> dict[str, torch.Tensor]:\n    kohya_state_dict = {}\n    default_adapter_name = \"default\"\n\n    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):\n        lora_config = peft_model.peft_config[default_adapter_name]\n        assert isinstance(lora_config, peft.LoraConfig)\n\n        peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)\n\n        kohya_state_dict.update(\n            _convert_peft_state_dict_to_kohya_state_dict(\n                lora_config=lora_config,\n                peft_state_dict=peft_state_dict,\n                prefix=kohya_prefix,\n                dtype=torch.float32,\n            )\n        )\n\n    return kohya_state_dict\n"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/serialization.py",
    "content": "import typing\nfrom pathlib import Path\n\nimport safetensors.torch\nimport torch\n\n\ndef save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file: typing.Union[Path, str]):\n    \"\"\"Save a state_dict to a file.\n\n    Both safetensors and torch formats are supported. The format is inferred from the `out_file` extension.\n    Supported extensions:\n    - \".ckpt\" -> torch\n    - \".pt\" -> torch\n    - \".safetensors\" -> safetensors\n\n    Args:\n        state_dict (typing.Dict[str, torch.Tensor]): The state_dict to save.\n        out_file (Path | str): The output file to save to.\n\n    Raises:\n        ValueError: If the `out_file` has an unsupported file extension.\n    \"\"\"\n    out_file = Path(out_file)\n    if out_file.suffix == \".ckpt\" or out_file.suffix == \".pt\":\n        torch.save(state_dict, out_file)\n    elif out_file.suffix == \".safetensors\":\n        safetensors.torch.save_file(state_dict, out_file)\n    else:\n        raise ValueError(f\"Unsupported file extension: '{out_file.suffix}'.\")\n\n\ndef load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str, torch.Tensor]:\n    \"\"\"Load a state_dict from a file.\n\n    Both safetensors and torch formats are supported. The format is inferred from the `in_file` extension.\n    Supported extensions:\n    - \".ckpt\" -> torch\n    - \".pt\" -> torch\n    - \".safetensors\" -> safetensors\n\n    Args:\n        in_file (Path | str): The input file to load from.\n\n    Raises:\n        ValueError: If the `in_file` has an unsupported file extension.\n\n    Returns:\n        typing.Dict[str, torch.Tensor]: The loaded state_dict.\n    \"\"\"\n    in_file = Path(in_file)\n    if in_file.suffix == \".ckpt\" or in_file.suffix == \".pt\":\n        return torch.load(in_file)\n    elif in_file.suffix == \".safetensors\":\n        return safetensors.torch.load_file(in_file)\n    else:\n        raise ValueError(f\"Unsupported file extension: '{in_file.suffix}'.\")\n"
  },
  {
    "path": "src/invoke_training/_shared/data/ARCHITECTURE.md",
    "content": "# Dataset Architecture\nDataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Each is explained in more detail below.\n\n## Datasets\n\nDatasets implement the [torch.utils.data.Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) interface.\n\nMost dataset classes act as an abstraction over a specific dataset format.\n \n## Transforms\n\nTransforms are functions applied to data loaded by Datasets. For example, the `SDImageTransform` implements image augmentations for Stable Diffusion training.\n\nTransforms are kept separate from the underlying datasets for several reasons:\n- It is easier to write tests for isolated transforms.\n- Modular transforms can often be re-used for multiple base datasets.\n- Modular transforms make it easy to customize datasets for different situations. For example, you may want to wrap a dataset with one set of transforms initially to populate a cache, and then with a different set of transforms to read from the cache.\n\nTransforms are applied to a dataset via the `TransformDataset` class.\n\n## DataLoaders\n\nThe dataset classes (with composed transforms) are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc.\n"
  },
  {
    "path": "src/invoke_training/_shared/data/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py",
    "content": "import typing\n\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torch.utils.data.sampler import RandomSampler, SequentialSampler\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (\n    build_aspect_ratio_bucket_manager,\n    sd_image_caption_collate_fn,\n)\nfrom invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler\nfrom invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler\nfrom invoke_training._shared.data.samplers.concat_sampler import ConcatSampler\nfrom invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler\nfrom invoke_training._shared.data.samplers.offset_sampler import OffsetSampler\nfrom invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform\nfrom invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\nfrom invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig\n\n\ndef build_dreambooth_sd_dataloader(\n    config: DreamboothSDDataLoaderConfig,\n    batch_size: int,\n    text_encoder_output_cache_dir: typing.Optional[str] = None,\n    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,\n    vae_output_cache_dir: typing.Optional[str] = None,\n    shuffle: bool = True,\n    sequential_batching: bool = False,\n) -> DataLoader:\n    \"\"\"Construct a DataLoader for a DreamBooth dataset for Stable Diffusion XL.\n\n    Args:\n 
       config (DreamboothSDDataLoaderConfig):\n        batch_size (int):\n        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be\n            loaded from.\n        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If\n            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.\n        shuffle (bool, optional): Whether to shuffle the dataset order.\n        sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than\n            interleaving class and instance examples. This is intended to be used when processing the entire dataset for\n            caching purposes. Defaults to False.\n\n    Returns:\n        DataLoader\n    \"\"\"\n    # Prepare instance dataset.\n    base_instance_dataset = ImageDirDataset(\n        config.instance_dataset.dataset_dir,\n        id_prefix=\"instance_\",\n        keep_in_memory=config.instance_dataset.keep_in_memory,\n    )\n    instance_dataset = TransformDataset(\n        base_instance_dataset,\n        [\n            ConstantFieldTransform(\"caption\", config.instance_caption),\n            ConstantFieldTransform(\"loss_weight\", 1.0),\n        ],\n    )\n    datasets = [instance_dataset]\n\n    # Prepare class dataset.\n    base_class_dataset = None\n    class_dataset = None\n    if config.class_dataset is not None:\n        base_class_dataset = ImageDirDataset(\n            config.class_dataset.dataset_dir, id_prefix=\"class_\", keep_in_memory=config.class_dataset.keep_in_memory\n        )\n        class_dataset = TransformDataset(\n            base_class_dataset,\n            [\n                ConstantFieldTransform(\"caption\", config.class_caption),\n                ConstantFieldTransform(\"loss_weight\", config.class_data_loss_weight),\n            ],\n        )\n        
datasets.append(class_dataset)\n\n    # Merge instance dataset and class dataset.\n    merged_dataset = ConcatDataset(datasets)\n\n    # Initialize either the fixed target resolution or aspect ratio buckets.\n    target_resolution = None\n    aspect_ratio_bucket_manager = None\n    instance_sampler = None\n    class_sampler = None\n    if config.aspect_ratio_buckets is None:\n        target_resolution = config.resolution\n        # TODO(ryand): Provide a seeded generator.\n        instance_sampler = RandomSampler(instance_dataset) if shuffle else SequentialSampler(instance_dataset)\n        if base_class_dataset is not None:\n            class_sampler = RandomSampler(class_dataset) if shuffle else SequentialSampler(class_dataset)\n            class_sampler = OffsetSampler(class_sampler, offset=len(base_instance_dataset))\n    else:\n        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)\n        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.\n        instance_sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n            bucket_manager=aspect_ratio_bucket_manager,\n            image_sizes=base_instance_dataset.get_image_dimensions(),\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=0,\n        )\n        if base_class_dataset is not None:\n            class_sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n                bucket_manager=aspect_ratio_bucket_manager,\n                image_sizes=base_class_dataset.get_image_dimensions(),\n                batch_size=batch_size,\n                shuffle=shuffle,\n                seed=0,\n            )\n            class_sampler = BatchOffsetSampler(class_sampler, offset=len(base_instance_dataset))\n\n    # Add transforms to the merged dataset.\n    all_transforms = []\n    if vae_output_cache_dir is None:\n        all_transforms.append(\n            SDImageTransform(\n                
image_field_names=[\"image\"],\n                fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n                resolution=target_resolution,\n                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n            )\n        )\n    else:\n        vae_cache = TensorDiskCache(vae_output_cache_dir)\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=vae_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field={\n                    \"vae_output\": \"vae_output\",\n                    \"original_size_hw\": \"original_size_hw\",\n                    \"crop_top_left_yx\": \"crop_top_left_yx\",\n                },\n            )\n        )\n        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.\n        all_transforms.append(DropFieldTransform(\"image\"))\n\n    if text_encoder_output_cache_dir is not None:\n        assert text_encoder_cache_field_to_output_field is not None\n        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=text_encoder_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=text_encoder_cache_field_to_output_field,\n            )\n        )\n\n    merged_dataset = TransformDataset(merged_dataset, all_transforms)\n\n    # Choose between sequential vs. 
interleaved merging of the instance and class samplers.\n    # Sequential sampling is typically used to populate a cache, because it guarantees that all examples will be\n    # included in an epoch.\n    samplers = [instance_sampler]\n    if class_sampler is not None:\n        samplers.append(class_sampler)\n    if sequential_batching:\n        sampler = ConcatSampler(samplers)\n    else:\n        sampler = InterleavedSampler(samplers)\n\n    if config.aspect_ratio_buckets is None:\n        return DataLoader(\n            merged_dataset,\n            sampler=sampler,\n            collate_fn=sd_image_caption_collate_fn,\n            batch_size=batch_size,\n            num_workers=config.dataloader_num_workers,\n        )\n    else:\n        # If config.aspect_ratio_buckets is not None, then we are using a batch sampler.\n        return DataLoader(\n            merged_dataset,\n            batch_sampler=sampler,\n            collate_fn=sd_image_caption_collate_fn,\n            num_workers=config.dataloader_num_workers,\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py",
    "content": "import typing\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (\n    build_aspect_ratio_bucket_manager,\n)\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (\n    sd_image_caption_collate_fn as flux_image_caption_collate_fn,\n)\nfrom invoke_training._shared.data.datasets.build_dataset import (\n    build_hf_hub_image_caption_dataset,\n    build_image_caption_dir_dataset,\n    build_image_caption_jsonl_dataset,\n)\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (\n    AspectRatioBucketBatchSampler,\n)\nfrom invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform\nfrom invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform\nfrom invoke_training._shared.data.transforms.flux_image_transform import FluxImageTransform\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCaptionDatasetConfig,\n    ImageCaptionDirDatasetConfig,\n    ImageCaptionJsonlDatasetConfig,\n)\n\n\ndef build_image_caption_flux_dataloader(  # noqa: C901\n    config: ImageCaptionFluxDataLoaderConfig,\n    batch_size: int,\n    use_masks: bool = False,\n    text_encoder_output_cache_dir: typing.Optional[str] = None,\n    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,\n    vae_output_cache_dir: typing.Optional[str] = None,\n    shuffle: bool = True,\n) -> DataLoader:\n    \"\"\"Construct a DataLoader for an image-caption dataset for Flux.1-dev.\n\n    Args:\n      
  config (ImageCaptionFluxDataLoaderConfig): The dataset config.\n        batch_size (int): The DataLoader batch size.\n        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be\n            loaded from. If set, then the TokenizeTransform will not be applied.\n        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If\n            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.\n        shuffle (bool, optional): Whether to shuffle the dataset order.\n    Returns:\n        DataLoader\n    \"\"\"\n    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):\n        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):\n        base_dataset = build_image_caption_jsonl_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):\n        base_dataset = build_image_caption_dir_dataset(config.dataset)\n    else:\n        raise ValueError(f\"Unexpected dataset config type: '{type(config.dataset)}'.\")\n\n    # Initialize either the fixed target resolution or aspect ratio buckets.\n    if config.aspect_ratio_buckets is None:\n        aspect_ratio_bucket_manager = None\n        batch_sampler = None\n    else:\n        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)\n        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.\n        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n            bucket_manager=aspect_ratio_bucket_manager,\n            image_sizes=base_dataset.get_image_dimensions(),\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=0,\n        )\n\n    all_transforms = []\n\n    if config.caption_prefix is not None:\n        
all_transforms.append(CaptionPrefixTransform(caption_field_name=\"caption\", prefix=config.caption_prefix + \" \"))\n\n    if vae_output_cache_dir is None:\n        image_field_names = [\"image\"]\n        if use_masks:\n            image_field_names.append(\"mask\")\n        else:\n            all_transforms.append(DropFieldTransform(\"mask\"))\n\n        all_transforms.append(\n            FluxImageTransform(\n                image_field_names=image_field_names,\n                fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n                resolution=config.resolution,\n                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n            )\n        )\n    else:\n        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.\n        all_transforms.append(DropFieldTransform(\"image\"))\n        all_transforms.append(DropFieldTransform(\"mask\"))\n\n        vae_cache = TensorDiskCache(vae_output_cache_dir)\n\n        cache_field_to_output_field = {\n            \"vae_output\": \"vae_output\",\n            \"original_size_hw\": \"original_size_hw\",\n            \"crop_top_left_yx\": \"crop_top_left_yx\",\n        }\n        if use_masks:\n            cache_field_to_output_field[\"mask\"] = \"mask\"\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=vae_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=cache_field_to_output_field,\n            )\n        )\n\n    if text_encoder_output_cache_dir is not None:\n        assert text_encoder_cache_field_to_output_field is not None\n        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=text_encoder_cache,\n                cache_key_field=\"id\",\n                
cache_field_to_output_field=text_encoder_cache_field_to_output_field,\n            )\n        )\n    dataset = TransformDataset(base_dataset, all_transforms)\n\n    if batch_sampler is None:\n        return DataLoader(\n            dataset,\n            shuffle=shuffle,\n            collate_fn=flux_image_caption_collate_fn,\n            batch_size=batch_size,\n            num_workers=config.dataloader_num_workers,\n        )\n    else:\n        return DataLoader(\n            dataset,\n            batch_sampler=batch_sampler,\n            collate_fn=flux_image_caption_collate_fn,\n            num_workers=config.dataloader_num_workers,\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py",
    "content": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_dataset import (\n    build_hf_hub_image_caption_dataset,\n    build_image_caption_dir_dataset,\n    build_image_caption_jsonl_dataset,\n)\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler\nfrom invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform\nfrom invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\nfrom invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager\nfrom invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCaptionDatasetConfig,\n    ImageCaptionDirDatasetConfig,\n    ImageCaptionJsonlDatasetConfig,\n)\n\n\ndef sd_image_caption_collate_fn(examples):\n    \"\"\"A batch collation function for the image-caption SDXL data loader.\"\"\"\n    out_examples = {\n        \"id\": [example[\"id\"] for example in examples],\n    }\n\n    if \"image\" in examples[0]:\n        out_examples[\"image\"] = torch.stack([example[\"image\"] for example in examples])\n\n    if \"original_size_hw\" in examples[0]:\n        out_examples[\"original_size_hw\"] = [example[\"original_size_hw\"] for example in examples]\n\n    if \"crop_top_left_yx\" in examples[0]:\n        out_examples[\"crop_top_left_yx\"] = [example[\"crop_top_left_yx\"] for example in examples]\n\n    if \"caption\" 
in examples[0]:\n        out_examples[\"caption\"] = [example[\"caption\"] for example in examples]\n\n    if \"loss_weight\" in examples[0]:\n        out_examples[\"loss_weight\"] = torch.tensor([example[\"loss_weight\"] for example in examples])\n\n    if \"prompt_embeds\" in examples[0]:\n        out_examples[\"prompt_embeds\"] = torch.stack([example[\"prompt_embeds\"] for example in examples])\n        out_examples[\"pooled_prompt_embeds\"] = torch.stack([example[\"pooled_prompt_embeds\"] for example in examples])\n\n    if \"text_encoder_output\" in examples[0]:\n        out_examples[\"text_encoder_output\"] = torch.stack([example[\"text_encoder_output\"] for example in examples])\n\n    if \"vae_output\" in examples[0]:\n        out_examples[\"vae_output\"] = torch.stack([example[\"vae_output\"] for example in examples])\n\n    if \"mask\" in examples[0]:\n        out_examples[\"mask\"] = torch.stack([example[\"mask\"] for example in examples])\n\n    return out_examples\n\n\ndef build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):\n    return AspectRatioBucketManager.from_constraints(\n        target_resolution=config.target_resolution,\n        start_dim=config.start_dim,\n        end_dim=config.end_dim,\n        divisible_by=config.divisible_by,\n    )\n\n\ndef build_image_caption_sd_dataloader(  # noqa: C901\n    config: ImageCaptionSDDataLoaderConfig,\n    batch_size: int,\n    use_masks: bool = False,\n    text_encoder_output_cache_dir: typing.Optional[str] = None,\n    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,\n    vae_output_cache_dir: typing.Optional[str] = None,\n    shuffle: bool = True,\n) -> DataLoader:\n    \"\"\"Construct a DataLoader for an image-caption dataset for Stable Diffusion XL.\n\n    Args:\n        config (ImageCaptionSDDataLoaderConfig): The dataset config.\n        batch_size (int): The DataLoader batch size.\n        text_encoder_output_cache_dir (str, optional): The directory 
where text encoder outputs are cached and should be\n            loaded from. If set, then the TokenizeTransform will not be applied.\n        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If\n            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.\n        shuffle (bool, optional): Whether to shuffle the dataset order.\n    Returns:\n        DataLoader\n    \"\"\"\n    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):\n        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):\n        base_dataset = build_image_caption_jsonl_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):\n        base_dataset = build_image_caption_dir_dataset(config.dataset)\n    else:\n        raise ValueError(f\"Unexpected dataset config type: '{type(config.dataset)}'.\")\n\n    # Initialize either the fixed target resolution or aspect ratio buckets.\n    if config.aspect_ratio_buckets is None:\n        target_resolution = config.resolution\n        aspect_ratio_bucket_manager = None\n        batch_sampler = None\n    else:\n        target_resolution = None\n        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)\n        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.\n        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n            bucket_manager=aspect_ratio_bucket_manager,\n            image_sizes=base_dataset.get_image_dimensions(),\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=0,\n        )\n\n    all_transforms = []\n\n    if config.caption_prefix is not None:\n        all_transforms.append(CaptionPrefixTransform(caption_field_name=\"caption\", prefix=config.caption_prefix + \" \"))\n\n    
if vae_output_cache_dir is None:\n        image_field_names = [\"image\"]\n        if use_masks:\n            image_field_names.append(\"mask\")\n        else:\n            all_transforms.append(DropFieldTransform(\"mask\"))\n\n        all_transforms.append(\n            SDImageTransform(\n                image_field_names=image_field_names,\n                fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n                resolution=target_resolution,\n                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n            )\n        )\n    else:\n        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.\n        all_transforms.append(DropFieldTransform(\"image\"))\n        all_transforms.append(DropFieldTransform(\"mask\"))\n\n        vae_cache = TensorDiskCache(vae_output_cache_dir)\n\n        cache_field_to_output_field = {\n            \"vae_output\": \"vae_output\",\n            \"original_size_hw\": \"original_size_hw\",\n            \"crop_top_left_yx\": \"crop_top_left_yx\",\n        }\n        if use_masks:\n            cache_field_to_output_field[\"mask\"] = \"mask\"\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=vae_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=cache_field_to_output_field,\n            )\n        )\n\n    if text_encoder_output_cache_dir is not None:\n        assert text_encoder_cache_field_to_output_field is not None\n        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=text_encoder_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=text_encoder_cache_field_to_output_field,\n            )\n        )\n\n    dataset = 
TransformDataset(base_dataset, all_transforms)\n\n    if batch_sampler is None:\n        return DataLoader(\n            dataset,\n            shuffle=shuffle,\n            collate_fn=sd_image_caption_collate_fn,\n            batch_size=batch_size,\n            num_workers=config.dataloader_num_workers,\n        )\n    else:\n        return DataLoader(\n            dataset,\n            batch_sampler=batch_sampler,\n            collate_fn=sd_image_caption_collate_fn,\n            num_workers=config.dataloader_num_workers,\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py",
    "content": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset\nfrom invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\nfrom invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.config import ImagePairPreferenceSDDataLoaderConfig\n\n\ndef sd_image_pair_preference_collate_fn(examples):\n    \"\"\"A batch collation function.\"\"\"\n\n    stack_keys = {\"image_0\", \"image_1\", \"prompt_embeds\", \"pooled_prompt_embeds\", \"text_encoder_output\", \"vae_output\"}\n    list_keys = {\n        \"id\",\n        \"original_size_hw_0\",\n        \"original_size_hw_1\",\n        \"crop_top_left_yx_0\",\n        \"crop_top_left_yx_1\",\n        \"prefer_0\",\n        \"prefer_1\",\n        \"caption\",\n    }\n\n    unhandled_keys = set(examples[0].keys()) - (stack_keys | list_keys)\n    if len(unhandled_keys) > 0:\n        raise ValueError(f\"The following keys are not handled by the collate function: {unhandled_keys}.\")\n\n    out_examples = {}\n\n    # torch.stack(...)\n    for k in stack_keys:\n        if k in examples[0]:\n            out_examples[k] = torch.stack([example[k] for example in examples])\n\n    # Basic list.\n    for k in list_keys:\n        if k in examples[0]:\n            out_examples[k] = [example[k] for example in examples]\n\n    return out_examples\n\n\ndef build_image_pair_preference_sd_dataloader(\n    config: ImagePairPreferenceSDDataLoaderConfig,\n    batch_size: int,\n    text_encoder_output_cache_dir: 
typing.Optional[str] = None,\n    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,\n    vae_output_cache_dir: typing.Optional[str] = None,\n    shuffle: bool = True,\n) -> DataLoader:\n    \"\"\"Construct a DataLoader for an image pair preference dataset for Stable Diffusion.\n\n    Args:\n        config (ImagePairPreferenceSDDataLoaderConfig): The dataset config.\n        batch_size (int): The DataLoader batch size.\n        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be\n            loaded from. If set, then the TokenizeTransform will not be applied.\n        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from.\n            VAE caching is not yet implemented for this data loader, so setting this raises a NotImplementedError.\n        shuffle (bool, optional): Whether to shuffle the dataset order.\n    Returns:\n        DataLoader\n    \"\"\"\n    if config.dataset.type == \"HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET\":\n        base_dataset = build_hf_image_pair_preference_dataset(config=config.dataset)\n    elif config.dataset.type == \"IMAGE_PAIR_PREFERENCE_DATASET\":\n        base_dataset = ImagePairPreferenceDataset(dataset_dir=config.dataset.dataset_dir)\n    else:\n        raise ValueError(f\"Unexpected dataset config type: '{type(config.dataset)}'.\")\n\n    target_resolution = config.resolution\n\n    all_transforms = []\n    if vae_output_cache_dir is None:\n        # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same\n        # transformations?\n        all_transforms.append(\n            SDImageTransform(\n                image_field_names=[\"image_0\"],\n                fields_to_normalize_to_range_minus_one_to_one=[\"image_0\"],\n                resolution=target_resolution,\n                aspect_ratio_bucket_manager=None,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n                orig_size_field_name=\"original_size_hw_0\",\n                crop_field_name=\"crop_top_left_yx_0\",\n            )\n        )\n        all_transforms.append(\n            SDImageTransform(\n                image_field_names=[\"image_1\"],\n                fields_to_normalize_to_range_minus_one_to_one=[\"image_1\"],\n                resolution=target_resolution,\n                aspect_ratio_bucket_manager=None,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n                orig_size_field_name=\"original_size_hw_1\",\n                crop_field_name=\"crop_top_left_yx_1\",\n            )\n        )\n    else:\n        raise NotImplementedError(\"VAE caching is not yet implemented.\")\n        # vae_cache = TensorDiskCache(vae_output_cache_dir)\n        # all_transforms.append(\n        #     LoadCacheTransform(\n        #         cache=vae_cache,\n        #         cache_key_field=\"id\",\n        #         cache_field_to_output_field={\n        #             \"vae_output\": \"vae_output\",\n        #             \"original_size_hw\": \"original_size_hw\",\n        #             \"crop_top_left_yx\": \"crop_top_left_yx\",\n        #         },\n        #     )\n        # )\n        # # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.\n        # all_transforms.append(DropFieldTransform(\"image\"))\n\n    if text_encoder_output_cache_dir is not None:\n        assert text_encoder_cache_field_to_output_field is not None\n        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=text_encoder_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=text_encoder_cache_field_to_output_field,\n            )\n        )\n\n    dataset = TransformDataset(base_dataset, all_transforms)\n\n    return DataLoader(\n        dataset,\n        shuffle=shuffle,\n        collate_fn=sd_image_pair_preference_collate_fn,\n        batch_size=batch_size,\n        num_workers=config.dataloader_num_workers,\n    )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py",
    "content": "from typing import Literal, Optional\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (\n    build_aspect_ratio_bucket_manager,\n    sd_image_caption_collate_fn,\n)\nfrom invoke_training._shared.data.datasets.build_dataset import (\n    build_hf_hub_image_caption_dataset,\n    build_image_caption_dir_dataset,\n    build_image_caption_jsonl_dataset,\n)\nfrom invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler\nfrom invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform\nfrom invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\nfrom invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform\nfrom invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform\nfrom invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCaptionDatasetConfig,\n    ImageCaptionDirDatasetConfig,\n    ImageCaptionJsonlDatasetConfig,\n    ImageDirDatasetConfig,\n)\n\n\ndef get_preset_ti_caption_templates(preset: Literal[\"object\", \"style\"]) -> list[str]:\n    if preset == \"object\":\n        return [\n            \"a photo of a {}\",\n            \"a rendering of a {}\",\n            \"a cropped photo of the {}\",\n            \"the photo of a {}\",\n       
     \"a photo of a clean {}\",\n            \"a photo of a dirty {}\",\n            \"a dark photo of the {}\",\n            \"a photo of my {}\",\n            \"a photo of the cool {}\",\n            \"a close-up photo of a {}\",\n            \"a bright photo of the {}\",\n            \"a cropped photo of a {}\",\n            \"a photo of the {}\",\n            \"a good photo of the {}\",\n            \"a photo of one {}\",\n            \"a close-up photo of the {}\",\n            \"a rendition of the {}\",\n            \"a photo of the clean {}\",\n            \"a rendition of a {}\",\n            \"a photo of a nice {}\",\n            \"a good photo of a {}\",\n            \"a photo of the nice {}\",\n            \"a photo of the small {}\",\n            \"a photo of the weird {}\",\n            \"a photo of the large {}\",\n            \"a photo of a cool {}\",\n            \"a photo of a small {}\",\n        ]\n    elif preset == \"style\":\n        return [\n            \"a painting in the style of {}\",\n            \"a rendering in the style of {}\",\n            \"a cropped painting in the style of {}\",\n            \"the painting in the style of {}\",\n            \"a clean painting in the style of {}\",\n            \"a dirty painting in the style of {}\",\n            \"a dark painting in the style of {}\",\n            \"a picture in the style of {}\",\n            \"a cool painting in the style of {}\",\n            \"a close-up painting in the style of {}\",\n            \"a bright painting in the style of {}\",\n            \"a good painting in the style of {}\",\n            \"a close-up painting in the style of {}\",\n            \"a rendition in the style of {}\",\n            \"a nice painting in the style of {}\",\n            \"a small painting in the style of {}\",\n            \"a weird painting in the style of {}\",\n            \"a large painting in the style of {}\",\n            \"a photo in the style of {}\",\n            \"an image 
in the style of {}\",\n            \"a drawing in the style of {}\",\n            \"a sketch in the style of {}\",\n            \"a digital work in the style of {}\",\n            \"a digital rendering in the style of {}\",\n            \"a photograph in the style of {}\",\n            \"photography in the style of {}\",\n        ]\n    else:\n        raise ValueError(f\"Unrecognized learnable property type: '{preset}'.\")\n\n\ndef build_textual_inversion_sd_dataloader(  # noqa: C901\n    config: TextualInversionSDDataLoaderConfig,\n    placeholder_token: str,\n    batch_size: int,\n    use_masks: bool = False,\n    vae_output_cache_dir: Optional[str] = None,\n    shuffle: bool = True,\n) -> DataLoader:\n    \"\"\"Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion.\n\n    Args:\n        config (TextualInversionSDDataLoaderConfig): The dataset config.\n        placeholder_token (str): The placeholder token being trained.\n        batch_size (int): The DataLoader batch size.\n        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. 
If\n            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.\n        shuffle (bool, optional): Whether to shuffle the dataset order.\n    Returns:\n        DataLoader\n    \"\"\"\n    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):\n        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):\n        base_dataset = build_image_caption_jsonl_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):\n        base_dataset = build_image_caption_dir_dataset(config.dataset)\n    elif isinstance(config.dataset, ImageDirDatasetConfig):\n        base_dataset = ImageDirDataset(\n            image_dir=config.dataset.dataset_dir, keep_in_memory=config.dataset.keep_in_memory\n        )\n    else:\n        raise ValueError(f\"Unexpected dataset config type: '{type(config.dataset)}'.\")\n\n    # Initialize either the fixed target resolution or aspect ratio buckets.\n    if config.aspect_ratio_buckets is None:\n        target_resolution = config.resolution\n        aspect_ratio_bucket_manager = None\n        batch_sampler = None\n    else:\n        target_resolution = None\n        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)\n        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.\n        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n            bucket_manager=aspect_ratio_bucket_manager,\n            image_sizes=base_dataset.get_image_dimensions(),\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=0,\n        )\n\n    if sum([config.caption_templates is not None, config.caption_preset is not None]) != 1:\n        raise ValueError(\"Either caption_templates or caption_preset must be set.\")\n\n    if config.caption_templates is not None:\n        # 
Overwrites the caption field. Typically used with an ImageDirDataset that does not have captions.\n        caption_tf = TemplateCaptionTransform(\n            field_name=\"caption_prefix\" if config.keep_original_captions else \"caption\",\n            placeholder_str=placeholder_token,\n            caption_templates=config.caption_templates,\n        )\n    elif config.caption_preset is not None:\n        # Overwrites the caption field. Typically used with an ImageDirDataset that does not have captions.\n        caption_tf = TemplateCaptionTransform(\n            field_name=\"caption_prefix\" if config.keep_original_captions else \"caption\",\n            placeholder_str=placeholder_token,\n            caption_templates=get_preset_ti_caption_templates(config.caption_preset),\n        )\n    else:\n        raise ValueError(\"Either caption_templates or caption_preset must be set.\")\n\n    all_transforms = [caption_tf]\n\n    if config.keep_original_captions:\n        # This will only work with a HFHubImageCaptionDataset or HFDirImageCaptionDataset that already has captions.\n        all_transforms.append(\n            ConcatFieldsTransform(\n                src_field_names=[\"caption_prefix\", \"caption\"], dst_field_name=\"caption\", separator=\" \"\n            )\n        )\n\n    if config.shuffle_caption_delimiter is not None:\n        all_transforms.append(ShuffleCaptionTransform(field_name=\"caption\", delimiter=config.shuffle_caption_delimiter))\n\n    if vae_output_cache_dir is None:\n        image_field_names = [\"image\"]\n        if use_masks:\n            image_field_names.append(\"mask\")\n        else:\n            all_transforms.append(DropFieldTransform(\"mask\"))\n\n        all_transforms.append(\n            SDImageTransform(\n                image_field_names=image_field_names,\n                fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n                resolution=target_resolution,\n                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n                center_crop=config.center_crop,\n                random_flip=config.random_flip,\n            )\n        )\n    else:\n        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.\n        all_transforms.append(DropFieldTransform(\"image\"))\n        all_transforms.append(DropFieldTransform(\"mask\"))\n\n        vae_cache = TensorDiskCache(vae_output_cache_dir)\n\n        cache_field_to_output_field = {\n            \"vae_output\": \"vae_output\",\n            \"original_size_hw\": \"original_size_hw\",\n            \"crop_top_left_yx\": \"crop_top_left_yx\",\n        }\n        if use_masks:\n            cache_field_to_output_field[\"mask\"] = \"mask\"\n\n        all_transforms.append(\n            LoadCacheTransform(\n                cache=vae_cache,\n                cache_key_field=\"id\",\n                cache_field_to_output_field=cache_field_to_output_field,\n            )\n        )\n\n    dataset = TransformDataset(base_dataset, all_transforms)\n\n    if batch_sampler is None:\n        return DataLoader(\n            dataset,\n            shuffle=shuffle,\n            collate_fn=sd_image_caption_collate_fn,\n            batch_size=batch_size,\n            num_workers=config.dataloader_num_workers,\n            persistent_workers=config.dataloader_num_workers > 0,\n        )\n    else:\n        return DataLoader(\n            dataset,\n            batch_sampler=batch_sampler,\n            collate_fn=sd_image_caption_collate_fn,\n            num_workers=config.dataloader_num_workers,\n            persistent_workers=config.dataloader_num_workers > 0,\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/build_dataset.py",
    "content": "from datasets import VerificationMode\n\nfrom invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImageCaptionDataset\nfrom invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset\nfrom invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset\nfrom invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCaptionDatasetConfig,\n    ImageCaptionDirDatasetConfig,\n    ImageCaptionJsonlDatasetConfig,\n)\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.config import HFHubImagePairPreferenceDatasetConfig\n\n\ndef build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetConfig) -> HFImageCaptionDataset:\n    return HFImageCaptionDataset.from_hub(\n        dataset_name=config.dataset_name,\n        hf_load_dataset_kwargs={\n            \"name\": config.dataset_config_name,\n            \"cache_dir\": config.hf_cache_dir,\n        },\n        image_column=config.image_column,\n        caption_column=config.caption_column,\n    )\n\n\ndef build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetConfig) -> HFImageCaptionDataset:\n    return ImageCaptionJsonlDataset(\n        jsonl_path=config.jsonl_path,\n        image_column=config.image_column,\n        caption_column=config.caption_column,\n        keep_in_memory=config.keep_in_memory,\n    )\n\n\ndef build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig) -> ImageCaptionDirDataset:\n    return ImageCaptionDirDataset(\n        dataset_dir=config.dataset_dir,\n        keep_in_memory=config.keep_in_memory,\n    )\n\n\ndef build_hf_image_pair_preference_dataset(\n    config: HFHubImagePairPreferenceDatasetConfig,\n) -> HFImagePairPreferenceDataset:\n    # HACK(ryand): This is currently hard-coded to just download a small slice of the very large\n    # 
'yuvalkirstain/pickapic_v2' dataset.\n    return HFImagePairPreferenceDataset.from_hub(\n        \"yuvalkirstain/pickapic_v2\",\n        split=\"train\",\n        hf_load_dataset_kwargs={\n            \"data_files\": {\n                # \"validation_unique\": \"data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet\",\n                \"train\": [\n                    \"data/train-00000-of-00645-b66ac786bf6fb553.parquet\",\n                    \"data/train-00001-of-00645-c7b349dd222d6515.parquet\",\n                    \"data/train-00002-of-00645-e4f54d615a978deb.parquet\",\n                    \"data/train-00003-of-00645-2b9d59bac8b433ff.parquet\",\n                    \"data/train-00004-of-00645-e4964649dc0ea543.parquet\",\n                    \"data/train-00005-of-00645-45e8efc0fe93f6e9.parquet\",\n                ]\n            },\n            # Disable checks so that it doesn't complain that I haven't downloaded the other splits.\n            \"verification_mode\": VerificationMode.NO_CHECKS,\n        },\n    )\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py",
    "content": "import os\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL.Image import Image\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass HFImageCaptionDataset(torch.utils.data.Dataset):\n    \"\"\"An image-caption dataset wrapper for Hugging Face datasets.\n\n    The wrapped HF dataset can be either from the HF hub, or in Imagefolder format\n    (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).\n    \"\"\"\n\n    def __init__(self, hf_dataset, image_column: str = \"image\", caption_column: str = \"text\"):\n        column_names = hf_dataset[\"train\"].column_names\n        if image_column not in column_names:\n            raise ValueError(\n                f\"The image_column='{image_column}' is not in the set of dataset column names: '{column_names}'.\"\n            )\n\n        if caption_column not in column_names:\n            raise ValueError(\n                f\"The caption_column='{caption_column}' is not in the set of dataset column names: '{column_names}'.\"\n            )\n\n        self._image_column = image_column\n\n        def preprocess(examples):\n            images = [image.convert(\"RGB\") for image in examples[image_column]]\n            return {\n                \"image\": images,\n                \"caption\": examples[caption_column],\n            }\n\n        self._hf_dataset = hf_dataset[\"train\"].with_transform(preprocess)\n\n    @classmethod\n    def from_dir(\n        cls,\n        dataset_dir: str,\n        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,\n        image_column: str = \"image\",\n        caption_column: str = \"text\",\n    ):\n        \"\"\"Initialize a HFImageCaptionDataset from a Hugging Face ImageFolder dataset directory\n        (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).\n\n        Args:\n            dataset_dir (str): The path to the dataset directory.\n            
hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.\n            image_column (str, optional): The name of the image column in the dataset. Defaults to \"image\".\n            caption_column (str, optional): The name of the caption column in the dataset. Defaults to \"text\".\n        \"\"\"\n        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}\n        data_files = {\"train\": os.path.join(dataset_dir, \"**\")}\n        # See more about loading custom images at\n        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder\n        hf_dataset = datasets.load_dataset(\"imagefolder\", data_files=data_files, **hf_load_dataset_kwargs)\n\n        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)\n\n    @classmethod\n    def from_hub(\n        cls,\n        dataset_name: str,\n        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,\n        image_column: str = \"image\",\n        caption_column: str = \"text\",\n    ):\n        \"\"\"Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset.\n\n        Args:\n            dataset_name (str): The HF Hub dataset name (a.k.a. path).\n            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.\n            image_column (str, optional): The name of the image column in the dataset. Defaults to \"image\".\n            caption_column (str, optional): The name of the caption column in the dataset. 
Defaults to \"text\".\n        \"\"\"\n        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}\n        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)\n\n        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)\n\n    def get_image_dimensions(self) -> list[Resolution]:\n        \"\"\"Get the dimensions of all images in the dataset.\n\n        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to\n        calculate this dynamically every time.\n        \"\"\"\n        image_dims: list[Resolution] = []\n        for i in range(len(self._hf_dataset)):\n            example = self._hf_dataset[i]\n            image: Image = example[self._image_column]\n            image_dims.append(Resolution(image.height, image.width))\n\n        return image_dims\n\n    def __len__(self) -> int:\n        \"\"\"Get the dataset length.\n\n        Returns:\n            int: The number of image-caption pairs in the dataset.\n        \"\"\"\n        return len(self._hf_dataset)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        \"\"\"Load the dataset example at index `idx`.\n\n        Raises:\n            IndexError: If `idx` is out of range.\n\n        Returns:\n            dict: A dataset example with 3 keys: \"image\", \"caption\", and \"id\".\n                The \"image\" key maps to a `PIL` image in RGB format.\n                The \"caption\" key maps to a string.\n                The \"id\" key is the example's index (often used for caching).\n        \"\"\"\n        example = self._hf_dataset[idx]\n        example[\"id\"] = idx\n        return example\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py",
    "content": "import io\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL import Image\n\n\nclass HFImagePairPreferenceDataset(torch.utils.data.Dataset):\n    \"\"\"A wrapper for the Hugging Face hub \"yuvalkirstain/pickapic_v2\" dataset\n    (https://huggingface.co/datasets/yuvalkirstain/pickapic_v2).\n\n    Designed to be expanded in the future to other HF image pair preference datasets.\n    \"\"\"\n\n    def __init__(\n        self,\n        hf_dataset,\n        skip_no_preference=True,\n        split: str = \"train\",\n        image_0_column: str = \"jpg_0\",\n        label_0_column: str = \"label_0\",\n        image_1_column: str = \"jpg_1\",\n        label_1_column: str = \"jpg_1\",\n        caption_column: str = \"caption\",\n    ):\n        \"\"\"\n        Args:\n            skip_no_preference (bool, optional): If True, skip image pairs without a preference.\n        \"\"\"\n        column_names = hf_dataset[split].column_names\n\n        for col_name in [image_0_column, label_0_column, image_1_column, label_1_column, caption_column]:\n            if col_name not in column_names:\n                raise ValueError(f\"Column '{col_name}' is not in the set of dataset column names: '{column_names}'.\")\n\n        eps = 0.0001\n\n        if skip_no_preference:\n            # Filter to only include pairs with a clear preference.\n            def filter(example: dict[str, typing.Any]) -> bool:\n                return abs(example[\"label_0\"] - example[\"label_1\"]) > eps\n\n            hf_dataset = hf_dataset.filter(filter)\n\n        def preprocess(examples):\n            image_0_list = [Image.open(io.BytesIO(image)).convert(\"RGB\") for image in examples[image_0_column]]\n            image_1_list = [Image.open(io.BytesIO(image)).convert(\"RGB\") for image in examples[image_1_column]]\n\n            image_0_is_better = []\n            image_1_is_better = []\n            for label_0, label_1 in zip(examples[\"label_0\"], 
examples[\"label_1\"]):\n                if (label_0 - label_1) > eps:\n                    # Label 0 is better.\n                    image_0_is_better.append(True)\n                    image_1_is_better.append(False)\n                elif (label_1 - label_0) > eps:\n                    # Label 1 is better.\n                    image_0_is_better.append(False)\n                    image_1_is_better.append(True)\n                else:\n                    # Tie.\n                    image_0_is_better.append(False)\n                    image_1_is_better.append(False)\n\n            return {\n                \"image_0\": image_0_list,\n                \"image_1\": image_1_list,\n                \"prefer_0\": image_0_is_better,\n                \"prefer_1\": image_1_is_better,\n                \"caption\": examples[caption_column],\n            }\n\n        self._hf_dataset = hf_dataset[split].with_transform(preprocess)\n\n    @classmethod\n    def from_hub(\n        cls,\n        dataset_name: str,\n        skip_no_preference: bool = True,\n        split: str = \"train\",\n        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,\n    ):\n        \"\"\"Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset.\n\n        Args:\n            dataset_name (str): The HF Hub dataset name (a.k.a. 
path).\n            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.\n        \"\"\"\n        if dataset_name != \"yuvalkirstain/pickapic_v2\":\n            raise NotImplementedError(\n                \"The HFImagePairPreferenceDataset class likely won't work with datasets other than \"\n                \"'yuvalkirstain/pickapic_v2'.\"\n            )\n\n        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}\n        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)\n\n        return cls(hf_dataset=hf_dataset, skip_no_preference=skip_no_preference, split=split)\n\n    def __len__(self) -> int:\n        \"\"\"Get the dataset length.\n\n        Returns:\n            int: The number of image pairs in the dataset.\n        \"\"\"\n        return len(self._hf_dataset)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        \"\"\"Load the dataset example at index `idx`.\n\n        Raises:\n            IndexError: If `idx` is out of range.\n\n        Returns:\n            dict: A dataset example with the following keys: [\"id\", \"image_0\", \"image_1\", \"prefer_0\", \"prefer_1\",\n                \"caption\"]\n                The image keys map to a `PIL` image in RGB format.\n                The prefer keys map to booleans (both are False for a tie).\n                The \"caption\" key maps to a string.\n                The \"id\" key is the example's index (often used for caching).\n        \"\"\"\n        example = self._hf_dataset[idx]\n        example[\"id\"] = idx\n        return example\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py",
    "content": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass ImageCaptionDirDataset(torch.utils.data.Dataset):\n    \"\"\"A dataset that loads images and captions from a directory of image files and .txt files.\"\"\"\n\n    def __init__(\n        self,\n        dataset_dir: str,\n        id_prefix: str = \"\",\n        image_extensions: typing.Optional[list[str]] = None,\n        caption_extension: str = \".txt\",\n        keep_in_memory: bool = False,\n    ):\n        \"\"\"Initialize an ImageDirDataset\n\n        Args:\n            image_dir (str): The directory to load images from.\n            id_prefix (str): A prefix added to the 'id' field in every example.\n            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not\n                case-sensitive). Defaults to [\".jpg\", \".jpeg\", \".png\"].\n            keep_in_memory (bool, optional): If True, keep all images loaded in memory. 
This improves performance for\n                datasets that are small enough to be kept in memory.\n        \"\"\"\n        super().__init__()\n        self._id_prefix = id_prefix\n        if image_extensions is None:\n            image_extensions = [\".jpg\", \".jpeg\", \".png\"]\n        image_extensions = [ext.lower() for ext in image_extensions]\n\n        # Determine the list of image paths to include in the dataset.\n        self._image_paths: list[str] = []\n        for image_file in os.listdir(dataset_dir):\n            image_path = os.path.join(dataset_dir, image_file)\n            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:\n                self._image_paths.append(image_path)\n        self._image_paths.sort()\n\n        # Load captions from .txt files for each image.\n        self._captions: list[str] = []\n        missing_captions: list[str] = []\n        for image_path in self._image_paths:\n            caption_path = os.path.splitext(image_path)[0] + caption_extension\n            if os.path.isfile(caption_path):\n                with open(caption_path, \"r\") as f:\n                    self._captions.append(f.read().strip())\n            else:\n                missing_captions.append(caption_path)\n        if len(missing_captions) > 0:\n            raise Exception(f\"The following expected caption files are missing: {missing_captions}\")\n\n        self._images = None\n        if keep_in_memory:\n            self._images = []\n            for image_path in self._image_paths:\n                self._images.append(self._load_image(image_path))\n\n    def _load_image(self, image_path: str) -> Image.Image:\n        # We call `convert(\"RGB\")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale\n        # images.\n        return Image.open(image_path).convert(\"RGB\")\n\n    def get_image_dimensions(self) -> list[Resolution]:\n        \"\"\"Get the dimensions of all images in 
the dataset.\n\n        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to\n        calculate this dynamically every time.\n        \"\"\"\n        image_dims: list[Resolution] = []\n        for i in range(len(self._image_paths)):\n            image_path = self._image_paths[i]\n            image = Image.open(image_path)\n            image_dims.append(Resolution(image.height, image.width))\n\n        return image_dims\n\n    def __len__(self) -> int:\n        return len(self._image_paths)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])\n        return {\"id\": f\"{self._id_prefix}{idx}\", \"image\": image, \"caption\": self._captions[idx]}\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py",
    "content": "import typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\nfrom pydantic import BaseModel\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\nfrom invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl\n\nIMAGE_COLUMN_DEFAULT = \"image\"\nCAPTION_COLUMN_DEFAULT = \"text\"\nMASK_COLUMN_DEFAULT = \"mask\"\n\n\nclass ImageCaptionExample(BaseModel):\n    image_path: str\n    mask_path: str | None = None\n    caption: str\n\n\nclass ImageCaptionJsonlDataset(torch.utils.data.Dataset):\n    \"\"\"A dataset that loads images and captions from a directory of image files and .txt files.\"\"\"\n\n    def __init__(\n        self,\n        jsonl_path: Path | str,\n        image_column: str = IMAGE_COLUMN_DEFAULT,\n        caption_column: str = CAPTION_COLUMN_DEFAULT,\n        keep_in_memory: bool = False,\n    ):\n        super().__init__()\n        self._jsonl_path = Path(jsonl_path)\n        self._image_column = image_column\n        self._caption_column = caption_column\n\n        data = load_jsonl(jsonl_path)\n        examples: list[ImageCaptionExample] = []\n        for d in data:\n            # Clear error messages here are helpful in the Gradio UI.\n            if image_column not in d:\n                raise ValueError(f\"Column '{image_column}' not found in jsonl file '{jsonl_path}'.\")\n            if caption_column not in d:\n                raise ValueError(f\"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.\")\n            examples.append(\n                ImageCaptionExample(\n                    image_path=d[image_column], mask_path=d.get(MASK_COLUMN_DEFAULT, None), caption=d[caption_column]\n                )\n            )\n        self.examples = examples\n\n        self._keep_in_memory = keep_in_memory\n        self._example_cache: dict[int, dict[str, typing.Any]] = {}\n\n    def save_jsonl(self):\n        data = []\n        for example in self.examples:\n         
   data.append(\n                {\n                    self._image_column: example.image_path,\n                    self._caption_column: example.caption,\n                    MASK_COLUMN_DEFAULT: example.mask_path,\n                }\n            )\n        save_jsonl(data, self._jsonl_path)\n\n    def _get_image_path(self, idx: int) -> str:\n        image_path = self.examples[idx].image_path\n        image_path = Path(image_path)\n\n        # image_path could be either absolute, or relative to the jsonl file.\n        if not image_path.is_absolute():\n            image_path = self._jsonl_path.parent / image_path\n\n        return image_path\n\n    def _get_mask_path(self, idx: int) -> str:\n        mask_path = self.examples[idx].mask_path\n        mask_path = Path(mask_path)\n\n        # mask_path could be either absolute, or relative to the jsonl file.\n        if not mask_path.is_absolute():\n            mask_path = self._jsonl_path.parent / mask_path\n\n        return mask_path\n\n    def _load_image(self, image_path: str) -> Image.Image:\n        # We call `convert(\"RGB\")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale\n        # images.\n        return Image.open(image_path).convert(\"RGB\")\n\n    def _load_mask(self, mask_path: str) -> Image.Image:\n        return Image.open(mask_path).convert(\"L\")\n\n    def _load_example(self, idx: int) -> dict[str, typing.Any]:\n        example = {\n            \"id\": str(idx),\n            \"image\": self._load_image(self._get_image_path(idx)),\n            \"caption\": self.examples[idx].caption,\n        }\n        if self.examples[idx].mask_path:\n            example[\"mask\"] = self._load_mask(self._get_mask_path(idx))\n        return example\n\n    def get_image_dimensions(self) -> list[Resolution]:\n        \"\"\"Get the dimensions of all images in the dataset.\n\n        TODO(ryand): Re-think this approach. For large datasets (e.g. 
streaming from S3) it doesn't make sense to\n        calculate this dynamically every time.\n        \"\"\"\n        image_dims: list[Resolution] = []\n        for i in range(len(self.examples)):\n            image = Image.open(self._get_image_path(i))\n            image_dims.append(Resolution(image.height, image.width))\n\n        return image_dims\n\n    def __len__(self) -> int:\n        return len(self.examples)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        if self._keep_in_memory:\n            if idx not in self._example_cache:\n                self._example_cache[idx] = self._load_example(idx)\n            # Return a shallow copy of the example to prevent the caller from modifying the cached example.\n            # Shallow rather than deep, because we don't want to copy the image data.\n            return self._example_cache[idx].copy()\n        return self._load_example(idx)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_dir_dataset.py",
    "content": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass ImageDirDataset(torch.utils.data.Dataset):\n    \"\"\"A dataset that loads image files from a directory.\"\"\"\n\n    def __init__(\n        self,\n        image_dir: str,\n        id_prefix: str = \"\",\n        image_extensions: typing.Optional[list[str]] = None,\n        keep_in_memory: bool = False,\n    ):\n        \"\"\"Initialize an ImageDirDataset\n\n        Args:\n            image_dir (str): The directory to load images from.\n            id_prefix (str): A prefix added to the 'id' field in every example.\n            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not\n                case-sensitive). Defaults to [\".jpg\", \".jpeg\", \".png\"].\n            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for\n            datasets that are small enough to be kept in memory.\n        \"\"\"\n        super().__init__()\n        self._id_prefix = id_prefix\n        if image_extensions is None:\n            image_extensions = [\".jpg\", \".jpeg\", \".png\"]\n        image_extensions = [ext.lower() for ext in image_extensions]\n\n        self._image_paths = []\n\n        for image_file in os.listdir(image_dir):\n            image_path = os.path.join(image_dir, image_file)\n            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:\n                self._image_paths.append(image_path)\n\n        self._images = None\n        if keep_in_memory:\n            self._images = []\n            for image_path in self._image_paths:\n                self._images.append(self._load_image(image_path))\n\n    def _load_image(self, image_path: str) -> Image.Image:\n        # We call `convert(\"RGB\")` to drop the alpha channel from RGBA images, or to 
repeat channels for greyscale\n        # images.\n        return Image.open(image_path).convert(\"RGB\")\n\n    def get_image_dimensions(self) -> list[Resolution]:\n        \"\"\"Get the dimensions of all images in the dataset.\n\n        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to\n        calculate this dynamically every time.\n        \"\"\"\n        image_dims: list[Resolution] = []\n        for i in range(len(self._image_paths)):\n            image_path = self._image_paths[i]\n            image = Image.open(image_path)\n            image_dims.append(Resolution(image.height, image.width))\n\n        return image_dims\n\n    def __len__(self) -> int:\n        return len(self._image_paths)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])\n        return {\"id\": f\"{self._id_prefix}{idx}\", \"image\": image}\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py",
    "content": "import os\nimport typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl\n\n\nclass ImagePairPreferenceDataset(torch.utils.data.Dataset):\n    def __init__(self, dataset_dir: str):\n        super().__init__()\n        self._dataset_dir = dataset_dir\n\n        self._metadata = load_jsonl(Path(dataset_dir) / \"metadata.jsonl\")\n\n    @classmethod\n    def save_metadata(\n        cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = \"metadata.jsonl\"\n    ) -> Path:\n        \"\"\"Load the dataset metadata from metadata.jsonl.\"\"\"\n        metadata_path = Path(dataset_dir) / metadata_file\n        save_jsonl(metadata, metadata_path)\n        return metadata_path\n\n    def __len__(self) -> int:\n        return len(self._metadata)\n\n    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:\n        # We call `convert(\"RGB\")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale\n        # images.\n        example = self._metadata[idx]\n        image_0_path = os.path.join(self._dataset_dir, example[\"image_0\"])\n        image_1_path = os.path.join(self._dataset_dir, example[\"image_1\"])\n        return {\n            \"id\": str(idx),\n            \"image_0\": Image.open(image_0_path).convert(\"RGB\"),\n            \"image_1\": Image.open(image_1_path).convert(\"RGB\"),\n            \"caption\": example[\"prompt\"],\n            \"prefer_0\": example[\"prefer_0\"],\n            \"prefer_1\": example[\"prefer_1\"],\n        }\n"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/transform_dataset.py",
    "content": "import typing\n\nimport torch.utils.data\n\n# The data type expected to be produced by the base dataset and handled by transforms.\nDataType = typing.Dict[str, typing.Any]\n\nTransformType = typing.Callable[[DataType], DataType]\n\n\nclass TransformDataset(torch.utils.data.Dataset):\n    \"\"\"A Dataset that wraps a base dataset and applies callable transforms to its outputs.\"\"\"\n\n    def __init__(self, base_dataset: torch.utils.data.Dataset, transforms: list[TransformType]) -> None:\n        super().__init__()\n        self._base_dataset = base_dataset\n        self._transforms = transforms\n\n    def __len__(self) -> int:\n        return len(self._base_dataset)\n\n    def __getitem__(self, idx: int) -> DataType:\n        example = self._base_dataset[idx]\n        for t in self._transforms:\n            example = t(example)\n        return example\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py",
    "content": "import copy\nimport logging\nimport math\nimport random\nfrom typing import Iterator\n\nfrom torch.utils.data import Sampler\n\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\nAspectRatioBuckets = dict[Resolution, list[int]]\n\n\nclass AspectRatioBucketBatchSampler(Sampler[list[int]]):\n    \"\"\"A batch sampler that adheres to aspect ratio buckets.\"\"\"\n\n    def __init__(\n        self,\n        buckets: AspectRatioBuckets,\n        batch_size: int,\n        shuffle: bool = False,\n        seed: int | None = None,\n    ) -> None:\n        \"\"\"Initialize AspectRatioBucketBatchSampler.\n\n        For most use cases, initialize via AspectRatioBucketBatchSampler.from_image_sizes(...).\n        \"\"\"\n        self._buckets = buckets\n        self._batch_size = batch_size\n        self._shuffle = shuffle\n        self._random = random.Random(seed)\n\n    def __str__(self) -> str:\n        buckets = self.get_buckets()\n        bucket_resolutions = sorted(list(buckets.keys()))\n        s = \"\"\n        for bucket_resolution in bucket_resolutions:\n            bucket_images = buckets[bucket_resolution]\n            s += f\"  {bucket_resolution.to_tuple()}: {len(bucket_images)}\\n\"\n        return s\n\n    @classmethod\n    def from_image_sizes(\n        cls,\n        bucket_manager: AspectRatioBucketManager,\n        image_sizes: list[Resolution],\n        batch_size: int,\n        shuffle: bool = False,\n        seed: int | None = None,\n    ):\n        \"\"\"Initialize from an AspectRatioBucketManager and the list of dataset image resolutions.\"\"\"\n        buckets = cls._build_bucket_to_index_map(bucket_manager, image_sizes)\n        return cls(buckets=buckets, batch_size=batch_size, shuffle=shuffle, seed=seed)\n\n    @classmethod\n    def _build_bucket_to_index_map(\n        cls,\n        bucket_manager: 
AspectRatioBucketManager,\n        image_sizes: list[Resolution],\n    ) -> AspectRatioBuckets:\n        bucket_to_indexes: AspectRatioBuckets = dict()\n\n        for bucket_resolution in bucket_manager.buckets:\n            bucket_to_indexes[bucket_resolution] = []\n\n        for index, image_size in enumerate(image_sizes):\n            aspect_ratio_bucket = bucket_manager.get_aspect_ratio_bucket(image_size)\n            bucket_to_indexes[aspect_ratio_bucket].append(index)\n\n        return bucket_to_indexes\n\n    def get_buckets(self) -> AspectRatioBuckets:\n        return copy.deepcopy(self._buckets)\n\n    def __iter__(self) -> Iterator[list[int]]:\n        batches: list[list[int]] = []\n\n        # TODO(ryand): If self._shuffle == False, should we still shuffle just with a fixed seed every time? If we\n        # don't shuffle at all then all of the batches from a bucket will be grouped together. If there's a correlation\n        # between aspect ratio and image content in a dataset, this could result in unevenly distributed image content\n        # over the dataset.\n\n        for bucket_resolution in sorted(list(self._buckets.keys())):\n            ordered_bucket_images = self._buckets[bucket_resolution].copy()\n            if self._shuffle:\n                # Shuffle the images within a bucket.\n                self._random.shuffle(ordered_bucket_images)\n\n            # Prepare batches for a single bucket.\n            batch_start = 0\n            while batch_start < len(ordered_bucket_images):\n                batch_end = min(batch_start + self._batch_size, len(ordered_bucket_images))\n                batches.append(ordered_bucket_images[batch_start:batch_end])\n                batch_start += self._batch_size\n\n        if self._shuffle:\n            # We've already shuffled the images within each bucket, now we shuffle the batches.\n            self._random.shuffle(batches)\n\n        yield from batches\n\n    def __len__(self) -> int:\n        
num_batches = 0\n        for bucket_images in self._buckets.values():\n            num_batches += math.ceil(len(bucket_images) / self._batch_size)\n        return num_batches\n\n\ndef log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: AspectRatioBucketBatchSampler):\n    \"\"\"Utility function for logging the aspect ratio buckets.\"\"\"\n    if not isinstance(batch_sampler, AspectRatioBucketBatchSampler):\n        return\n\n    log = \"Aspect Ratio Buckets:\\n\"\n    log += str(batch_sampler)\n    logger.info(log)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/batch_offset_sampler.py",
    "content": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass BatchOffsetSampler(Sampler[int]):\n    \"\"\"A sampler that wraps a batch sampler and applies an offset to all returned batch elements.\"\"\"\n\n    def __init__(self, sampler: Sampler[int], offset: int):\n        self._sampler = sampler\n        self._offset = offset\n\n    def __iter__(self) -> typing.Iterator[int]:\n        for batch in self._sampler:\n            offset_batch = [x + self._offset for x in batch]\n            yield offset_batch\n\n    def __len__(self) -> int:\n        return len(self._sampler)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/concat_sampler.py",
    "content": "import itertools\nimport typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\nclass ConcatSampler(Sampler[T_co]):\n    \"\"\"A meta-Sampler that concatenates multiple samplers.\n\n    Example:\n        sampler 1:           ABCD\n        sampler 2:           EFG\n        sampler 3:           HIJKLM\n        ConcatSampler:       ABCDEFGHIJKLM\n    \"\"\"\n\n    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:\n        self._samplers = samplers\n\n    def __iter__(self) -> typing.Iterator[T_co]:\n        return itertools.chain(*self._samplers)\n\n    def __len__(self) -> int:\n        return sum([len(s) for s in self._samplers])\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/interleaved_sampler.py",
    "content": "import typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\nclass InterleavedSampler(Sampler[T_co]):\n    \"\"\"A meta-Sampler that interleaves multiple samplers.\n\n    The length of this sampler is based on the length of the shortest input sampler. All samplers will contribute the\n    same number of samples to the interleaved output.\n\n    Example:\n        sampler 1:           ABCD\n        sampler 2:           EFG\n        sampler 3:           HIJKLM\n        interleaved sampler: AEHBFICGJ\n    \"\"\"\n\n    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:\n        self._samplers = samplers\n        self._min_sampler_len = min([len(s) for s in self._samplers])\n\n    def __iter__(self) -> typing.Iterator[T_co]:\n        sampler_iters = [iter(s) for s in self._samplers]\n        while True:\n            samples = []\n            for sampler_iter in sampler_iters:\n                try:\n                    samples.append(next(sampler_iter))\n                except StopIteration:\n                    # The end of the shortest sampler has been reached.\n                    return\n\n            yield from samples\n\n    def __len__(self) -> int:\n        return self._min_sampler_len * len(self._samplers)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/offset_sampler.py",
    "content": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass OffsetSampler(Sampler[int]):\n    \"\"\"A sampler that wraps another sampler and applies an offset to all returned values.\"\"\"\n\n    def __init__(self, sampler: Sampler[int], offset: int):\n        self._sampler = sampler\n        self._offset = offset\n\n    def __iter__(self) -> typing.Iterator[int]:\n        for idx in self._sampler:\n            yield idx + self._offset\n\n    def __len__(self) -> int:\n        return len(self._sampler)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/caption_prefix_transform.py",
    "content": "import typing\n\n\nclass CaptionPrefixTransform:\n    \"\"\"A transform that adds a prefix to all example captions.\"\"\"\n\n    def __init__(self, caption_field_name: str, prefix: str):\n        self._caption_field_name = caption_field_name\n        self._prefix = prefix\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        data[self._caption_field_name] = self._prefix + data[self._caption_field_name]\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/concat_fields_transform.py",
    "content": "import typing\n\n\nclass ConcatFieldsTransform:\n    \"\"\"A transform that concatenate multiple string fields.\"\"\"\n\n    def __init__(self, src_field_names: list[str], dst_field_name: str, separator: str = \" \"):\n        self._src_field_names = src_field_names\n        self._dst_field_name = dst_field_name\n        self._separator = separator\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        result = self._separator.join([data[field_name] for field_name in self._src_field_names])\n        data[self._dst_field_name] = result\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/constant_field_transform.py",
    "content": "import typing\n\n\nclass ConstantFieldTransform:\n    \"\"\"A simple transform that adds a constant field to every example.\"\"\"\n\n    def __init__(self, field_name: str, field_value: typing.Any):\n        self._field_name = field_name\n        self._field_value = field_value\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        data[self._field_name] = self._field_value\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/drop_field_transform.py",
    "content": "import typing\n\n\nclass DropFieldTransform:\n    \"\"\"A simple transform that drops a field from an example.\"\"\"\n\n    def __init__(self, field_to_drop: str):\n        self._field_to_drop = field_to_drop\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        if self._field_to_drop in data:\n            del data[self._field_to_drop]\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/flux_image_transform.py",
    "content": "import typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution\nfrom invoke_training._shared.data.utils.resize import resize_to_cover\n\n\nclass FluxImageTransform:\n    \"\"\"A transform that prepares and augments images for Flux.1-dev training.\"\"\"\n\n    def __init__(\n        self,\n        image_field_names: list[str],\n        fields_to_normalize_to_range_minus_one_to_one: list[str],\n        resolution: int | None = 512,\n        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,\n        random_flip: bool = True,\n        center_crop: bool = True,\n    ):\n        \"\"\"Initialize FluxImageTransform.\n\n        Args:\n            image_field_names (list[str]): The field names of the images to be transformed.\n            resolution (int): The image resolution that will be produced. One of `resolution` and\n                `aspect_ratio_bucket_manager` should be non-None.\n            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the\n                target resolution for each image. One of `resolution` and `aspect_ratio_bucket_manager` should be\n                non-None.\n            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. 
If\n                False, crop at a random location.\n            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.\n        \"\"\"\n        self.image_field_names = image_field_names\n        self.fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one\n        self.resolution = resolution\n        self.aspect_ratio_bucket_manager = aspect_ratio_bucket_manager\n        self.random_flip = random_flip\n        self.center_crop = center_crop\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901\n        image_fields: dict = {}\n        for field_name in self.image_field_names:\n            image_fields[field_name] = data[field_name]\n\n        # Get the first image to determine original size and resolution\n        first_image = next(iter(image_fields.values()))\n        original_size_hw = (first_image.height, first_image.width)\n\n        for field_name, image in image_fields.items():\n            # Determine the target image resolution.\n            if self.resolution is not None:\n                resolution = self.resolution\n                resolution_obj = Resolution(resolution, resolution)\n            else:\n                resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket(\n                    Resolution.parse(original_size_hw)\n                )\n\n            image = resize_to_cover(image, resolution_obj)\n\n            # Apply cropping and record top left crop position\n            if self.center_crop:\n                top_left_y = max(0, (image.height - resolution_obj.height) // 2)\n                top_left_x = max(0, (image.width - resolution_obj.width) // 2)\n                image = transforms.CenterCrop(resolution_obj.to_tuple())(image)\n            else:\n                crop_transform = transforms.RandomCrop(resolution_obj.to_tuple())\n                top_left_y, top_left_x, h, w = 
crop_transform.get_params(image, resolution_obj.to_tuple())\n                image = crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width)\n\n            # Apply random flip and update top left crop position accordingly\n            if self.random_flip:\n                # TODO: Use a seed for repeatable results\n                import random\n\n                if random.random() < 0.5:\n                    top_left_x = original_size_hw[1] - image.width - top_left_x\n                    image = transforms.RandomHorizontalFlip(p=1.0)(image)\n\n            image = transforms.ToTensor()(image)\n\n            if field_name in self.fields_to_normalize_to_range_minus_one_to_one:\n                image_fields[field_name] = transforms.Normalize([0.5], [0.5])(image)\n            else:\n                image_fields[field_name] = image\n\n        # Store the processed images and metadata\n        for field_name, image in image_fields.items():\n            data[field_name] = image\n\n        # Add metadata fields expected by VAE caching\n        data[\"original_size_hw\"] = original_size_hw\n        data[\"crop_top_left_yx\"] = (top_left_y, top_left_x)\n\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/load_cache_transform.py",
    "content": "import typing\n\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\n\n\nclass LoadCacheTransform:\n    \"\"\"A transform that loads data from a TensorDiskCache.\"\"\"\n\n    def __init__(\n        self, cache: TensorDiskCache, cache_key_field: str, cache_field_to_output_field: typing.Dict[str, str]\n    ):\n        \"\"\"Initialize LoadCacheTransform.\n\n        Args:\n            cache (TensorDiskCache): The cache to load from.\n            cache_key_field (str): The name of the field to use as the cache key.\n            cache_field_to_output_field (typing.Dict[str, str]): A map of field names in the cached data to the field\n                names where they should be inserted in the example data.\n        \"\"\"\n        self._cache = cache\n        self._cache_key_field = cache_key_field\n        self._cache_field_to_output_field = cache_field_to_output_field\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        key = data[self._cache_key_field]\n\n        cache_data = self._cache.load(key)\n\n        for src, dst in self._cache_field_to_output_field.items():\n            data[dst] = cache_data[src]\n\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/sd_image_transform.py",
    "content": "import random\nimport typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution\nfrom invoke_training._shared.data.utils.resize import resize_to_cover\n\n\nclass SDImageTransform:\n    \"\"\"A transform that prepares and augments images for Stable Diffusion training.\"\"\"\n\n    def __init__(\n        self,\n        image_field_names: list[str],\n        fields_to_normalize_to_range_minus_one_to_one: list[str],\n        resolution: int | tuple[int, int] | Resolution | None,\n        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,\n        center_crop: bool = True,\n        random_flip: bool = False,\n        orig_size_field_name: str = \"original_size_hw\",\n        crop_field_name: str = \"crop_top_left_yx\",\n    ):\n        \"\"\"Initialize SDImageTransform.\n\n        Args:\n            image_field_names (list[str]): The field names of the images to be transformed.\n            resolution (Resolution): The image resolution that will be produced. One of `resolution` and\n                `aspect_ratio_bucket_manager` should be non-None.\n            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the\n                target resolution for each image. One of `resolution` and `aspect_ratio_bucket_manager` should be\n                non-None.\n            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. 
If\n                False, crop at a random location.\n            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.\n        \"\"\"\n        self._image_field_names = image_field_names\n        self._fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one\n        if resolution is not None and aspect_ratio_bucket_manager is not None:\n            raise ValueError(\"Only one of `resolution` or `aspect_ratio_bucket_manager` should be set.\")\n\n        if resolution is None and aspect_ratio_bucket_manager is None:\n            raise ValueError(\"One of `resolution` or `aspect_ratio_bucket_manager` must be set.\")\n\n        self._resolution = Resolution.parse(resolution) if resolution is not None else None\n        self._aspect_ratio_bucket_manager = aspect_ratio_bucket_manager\n        self._center_crop_enabled = center_crop\n        self._random_flip_enabled = random_flip\n        self._flip_transform = transforms.RandomHorizontalFlip(p=1.0)\n        self._to_tensor_transform = transforms.ToTensor()\n        # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0].\n        # Normalize applies the following transform: out = (in - 0.5) / 0.5\n        self._normalize_image_transform = transforms.Normalize([0.5], [0.5])\n\n        self._orig_size_field_name = orig_size_field_name\n        self._crop_field_name = crop_field_name\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901\n        # This SDXL image pre-processing logic is adapted from:\n        # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873\n\n        image_fields: dict = {}\n        for field_name in self._image_field_names:\n            image_fields[field_name] = data[field_name]\n        sizes = [image.size for image in image_fields.values()]\n        # All images 
should have the same size.\n        assert all(size == sizes[0] for size in sizes)\n\n        # Helper function to access the first image, which is sometimes used to infer the shape of all images.\n        def get_first_image():\n            return next(iter(image_fields.values()))\n\n        original_size_hw = (get_first_image().height, get_first_image().width)\n\n        # Determine the target image resolution.\n        if self._resolution is not None:\n            resolution = self._resolution\n        else:\n            resolution = self._aspect_ratio_bucket_manager.get_aspect_ratio_bucket(Resolution.parse(original_size_hw))\n\n        # Resize to cover the target resolution while preserving aspect ratio.\n        for field_name, image in image_fields.items():\n            image_fields[field_name] = resize_to_cover(image, resolution)\n\n        # Apply cropping, and record top left crop position.\n        if self._center_crop_enabled:\n            top_left_y = max(0, (get_first_image().height - resolution.height) // 2)\n            top_left_x = max(0, (get_first_image().width - resolution.width) // 2)\n        else:\n            crop_transform = transforms.RandomCrop(resolution.to_tuple())\n            top_left_y, top_left_x, h, w = crop_transform.get_params(get_first_image(), resolution.to_tuple())\n        for field_name, image in image_fields.items():\n            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution.height, resolution.width)\n\n        # Apply random flip and update top left crop position accordingly.\n        # TODO(ryand): Use a seed for repeatable results.\n        if self._random_flip_enabled and random.random() < 0.5:\n            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x\n            for field_name, image in image_fields.items():\n                image_fields[field_name] = self._flip_transform(image)\n\n        crop_top_left_yx = (top_left_y, top_left_x)\n\n        # Convert to 
Tensors.\n        for field_name, image in image_fields.items():\n            image_fields[field_name] = self._to_tensor_transform(image)\n\n        # Normalize to range [-1.0, 1.0].\n        # HACK(ryand): We should find a better way to determine the normalization range of each image field.\n        for field_name, image in image_fields.items():\n            if field_name in self._fields_to_normalize_to_range_minus_one_to_one:\n                image_fields[field_name] = self._normalize_image_transform(image)\n\n        data[self._orig_size_field_name] = original_size_hw\n        data[self._crop_field_name] = crop_top_left_yx\n        for field_name, image in image_fields.items():\n            data[field_name] = image\n\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py",
    "content": "import typing\n\nimport numpy as np\n\n\nclass ShuffleCaptionTransform:\n    \"\"\"A transform that applies shuffle transformations to character-delimited captions.\n\n    Example:\n    - Original: \"unreal engine, render of sci-fi helmet, dramatic lighting\"\n    - Shuffled: \"render of sci-fi helmet, unreal engine, dramatic lighting\"\n    \"\"\"\n\n    def __init__(self, field_name: str, delimiter: str = \",\", seed: int = 0):\n        self._field_name = field_name\n        self._delimiter = delimiter\n        self._rng = np.random.default_rng(seed)\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        caption: str = data[self._field_name]\n        caption_chunks = caption.split(self._delimiter)\n        caption_chunks = [s.strip() for s in caption_chunks]\n\n        self._rng.shuffle(caption_chunks)\n\n        join_str = self._delimiter + \" \"\n        data[self._field_name] = join_str.join(caption_chunks)\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/template_caption_transform.py",
    "content": "import typing\n\nimport numpy as np\n\n\nclass TemplateCaptionTransform:\n    \"\"\"A simple transform that constructs a caption for each example by combining a caption template with the\n    placeholder string.\n    \"\"\"\n\n    def __init__(self, field_name: str, placeholder_str: str, caption_templates: list[str], seed: int = 0):\n        self._field_name = field_name\n        self._placeholder_str = placeholder_str\n        self._caption_templates = caption_templates\n        self._rng = np.random.default_rng(seed)\n\n    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:\n        caption = self._rng.choice(self._caption_templates).format(self._placeholder_str)\n        # Assert that the template was well-formed such that the placeholder string is in the output caption.\n        assert self._placeholder_str in caption\n\n        data[self._field_name] = caption\n        return data\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/tensor_disk_cache.py",
    "content": "import os\nimport typing\n\nimport torch\n\n\nclass TensorDiskCache:\n    \"\"\"A data cache that caches `torch.Tensor`s on disk.\"\"\"\n\n    def __init__(self, cache_dir: str):\n        super().__init__()\n        self._cache_dir = cache_dir\n\n        os.makedirs(self._cache_dir, exist_ok=True)\n\n    def _get_path(self, key: int):\n        \"\"\"Get the cache file path for `key`.\n        Args:\n            key (int): The cache key.\n        Returns:\n            str: The cache file path.\n        \"\"\"\n        return os.path.join(self._cache_dir, f\"{key}.pt\")\n\n    def save(self, key: int, data: typing.Dict[str, torch.Tensor]):\n        \"\"\"Save data in the cache.\n        Raises:\n            AssertionError: If an entry already exists in the cache for this `key`.\n        Args:\n            key (int): The cache key.\n            data (typing.Dict[str, torch.Tensor]): The data to save.\n        \"\"\"\n        # torch.save() supports a range of different data types, but it is cleaner if we force everyone to use a dict.\n        # This allows for more reusable cache loading code.\n        assert isinstance(data, dict)\n\n        save_path = self._get_path(key)\n        assert not os.path.exists(save_path)\n        torch.save(data, save_path)\n\n    def load(self, key: int) -> typing.Dict[str, torch.Tensor]:\n        \"\"\"Load data from the cache.\n        Args:\n            key (int): The cache key to load.\n        Returns:\n            typing.Dict[str, torch.Tensor]: Data loaded from the cache.\n        \"\"\"\n        return torch.load(self._get_path(key))\n"
  },
  {
    "path": "src/invoke_training/_shared/data/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py",
    "content": "from invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass AspectRatioBucketManager:\n    def __init__(self, buckets: set[Resolution]):\n        self.buckets = buckets\n\n    @classmethod\n    def from_constraints(cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int) -> None:\n        buckets = cls.build_aspect_ratio_buckets(\n            target_resolution=target_resolution,\n            start_dim=start_dim,\n            end_dim=end_dim,\n            divisible_by=divisible_by,\n        )\n        return cls(buckets)\n\n    @classmethod\n    def build_aspect_ratio_buckets(\n        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int\n    ) -> set[Resolution]:\n        \"\"\"Prepare a set of aspect ratios.\n\n        Args:\n            target_resolution (Resolution): All resolutions in the returned set will aim to have close to\n                (but <=) `target_resolution * target_resolution` pixels.\n            start_dim (int):\n            end_dim (int):\n            divisible_by (int): All dimensions in the returned set of resolutions will be divisible by `divisible_by`.\n\n        Returns:\n            set[tuple[int, int]]: The aspect ratio bucket resolutions.\n        \"\"\"\n        # Validate target_resolution.\n        assert target_resolution % divisible_by == 0\n\n        # Validate start_dim, end_dim.\n        assert start_dim <= end_dim\n        assert start_dim % divisible_by == 0\n        assert end_dim % divisible_by == 0\n\n        target_size = target_resolution * target_resolution\n\n        buckets = set()\n\n        height = start_dim\n        while height <= end_dim:\n            width = (target_size // height) // divisible_by * divisible_by\n            buckets.add(Resolution(height, width))\n            buckets.add(Resolution(width, height))\n\n            height += divisible_by\n\n        return buckets\n\n    def get_aspect_ratio_bucket(self, resolution: 
Resolution):\n        \"\"\"Get the bucket with the closest aspect ratio to 'resolution'.\"\"\"\n        # Note: If this is ever found to be a bottleneck, there is a clearly-more-efficient implementation using bisect.\n        return min(self.buckets, key=lambda x: abs(x.aspect_ratio() - resolution.aspect_ratio()))\n"
  },
  {
    "path": "src/invoke_training/_shared/data/utils/resize.py",
    "content": "import math\n\nfrom PIL.Image import Image\nfrom torchvision import transforms\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\ndef resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:\n    \"\"\"Resize image to the smallest size that covers 'size_to_cover' while preserving its aspect ratio.\n\n    In other words, achieve the following:\n    - resized_height >= size_to_cover.height\n    - resized_width >= size_to_cover.width\n    - resized_height == size_to_cover.height or resized_width == size_to_cover.width\n    - 'image' aspect ratio is preserved.\n    \"\"\"\n\n    scale_to_height = size_to_cover.height / image.height\n    scale_to_width = size_to_cover.width / image.width\n\n    if scale_to_height > scale_to_width:\n        resize_height = size_to_cover.height\n        resize_width = math.ceil(image.width * scale_to_height)\n    else:\n        resize_width = size_to_cover.width\n        resize_height = math.ceil(image.height * scale_to_width)\n\n    resize_transform = transforms.Resize(\n        (resize_height, resize_width), interpolation=transforms.InterpolationMode.BILINEAR\n    )\n\n    return resize_transform(image)\n"
  },
  {
    "path": "src/invoke_training/_shared/data/utils/resolution.py",
    "content": "from typing import Union\n\n\nclass Resolution:\n    def __init__(self, height: int, width: int):\n        self.height = height\n        self.width = width\n\n    @classmethod\n    def parse(cls, resolution: Union[int, tuple[int, int], \"Resolution\"]):\n        \"\"\"Initialize a Resolution object from another type.\"\"\"\n        if isinstance(resolution, int):\n            # Assume square resolution.\n            return cls(resolution, resolution)\n        elif isinstance(resolution, tuple):\n            height, width = resolution\n            return cls(height, width)\n        elif isinstance(resolution, cls):\n            return cls(resolution.height, resolution.width)\n        else:\n            raise ValueError(f\"Unsupported resolution type: '{type(resolution)}'.\")\n\n    def aspect_ratio(self):\n        return self.height / self.width\n\n    def to_tuple(self) -> tuple[int, int]:\n        return (self.height, self.width)\n\n    def __eq__(self, other: \"Resolution\") -> bool:\n        return self.to_tuple() == other.to_tuple()\n\n    def __lt__(self, other: \"Resolution\") -> bool:\n        return self.to_tuple() < other.to_tuple()\n\n    def __hash__(self):\n        return hash(self.to_tuple())\n"
  },
  {
    "path": "src/invoke_training/_shared/flux/encoding_utils.py",
    "content": "import logging\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n\n\ndef get_clip_prompt_embeds(\n    prompt: Union[str, List[str]],\n    tokenizer: CLIPTokenizer,\n    text_encoder: CLIPTextModel,\n    device: torch.device,\n    num_images_per_prompt: int = 1,\n    tokenizer_max_length: int = 77,\n    logger: logging.Logger | None = None,\n) -> torch.FloatTensor:\n    \"\"\"Encodes the prompt using CLIP text encoder and returns pooled embeddings.\"\"\"\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    # Process text input with the tokenizer\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer_max_length,\n        truncation=True,\n        return_overflowing_tokens=False,\n        return_length=False,\n        return_tensors=\"pt\",\n    )\n\n    text_input_ids = text_inputs.input_ids\n    untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n    # Check if truncation occurred\n    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])\n        if logger is not None:\n            logger.warning(f\"Warning: The following part of your input was truncated: {removed_text}\")\n\n    # Get prompt embeddings through the text encoder\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    # Duplicate text embeddings for each generation per prompt\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n    prompt_embeds = 
prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef get_t5_prompt_embeds(\n    prompt: Union[str, List[str]],\n    tokenizer: T5TokenizerFast,\n    text_encoder: T5EncoderModel,\n    device: torch.device,\n    num_images_per_prompt: int = 1,\n    tokenizer_max_length: int = 512,\n    logger: logging.Logger | None = None,\n) -> torch.FloatTensor:\n    \"\"\"Encodes the prompt using T5 text encoder.\"\"\"\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    # Process text input with the tokenizer\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=tokenizer_max_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    untruncated_ids = tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n    # Check if truncation occurred\n    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n        removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])\n        if logger is not None:\n            logger.warning(f\"Warning: The following part of your input was truncated: {removed_text}\")\n\n    # Get prompt embeddings through the text encoder\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    # Get shape and duplicate for multiple generations\n    _, seq_len, _ = prompt_embeds.shape\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef handle_lora_scale(\n    clip_text_encoder: CLIPTextModel,\n    
t5_text_encoder: T5EncoderModel,\n    lora_scale: Optional[float] = None,\n    use_peft_backend: bool = False,\n):\n    \"\"\"Handles LoRA scale adjustments for text encoders.\"\"\"\n    if lora_scale is not None and use_peft_backend:\n        from peft.utils import scale_lora_layers\n\n        # Apply LoRA scaling to text encoders if they exist\n        if clip_text_encoder is not None:\n            scale_lora_layers(clip_text_encoder, lora_scale)\n        if t5_text_encoder is not None:\n            scale_lora_layers(t5_text_encoder, lora_scale)\n\n        return True\n    return False\n\n\ndef reset_lora_scale(\n    clip_text_encoder: CLIPTextModel,\n    t5_text_encoder: T5EncoderModel,\n    lora_scale: Optional[float] = None,\n    lora_applied: bool = False,\n    use_peft_backend: bool = False,\n):\n    \"\"\"Resets LoRA scale for text encoders if it was applied.\"\"\"\n    if lora_applied and use_peft_backend:\n        from peft.utils import unscale_lora_layers\n\n        # Reset LoRA scaling\n        if clip_text_encoder is not None:\n            unscale_lora_layers(clip_text_encoder, lora_scale)\n        if t5_text_encoder is not None:\n            unscale_lora_layers(t5_text_encoder, lora_scale)\n\n\n# A lot of this code was adapted from:\n# https://github.com/huggingface/diffusers/blob/ea81a4228d8ff16042c3ccaf61f0e588e60166cd/src/diffusers/pipelines/flux/pipeline_flux.py#L310-L387\ndef encode_prompt(\n    prompt: Union[str, List[str]],\n    prompt_2: Optional[Union[str, List[str]]],\n    clip_tokenizer: CLIPTokenizer,\n    t5_tokenizer: T5TokenizerFast,\n    clip_text_encoder: CLIPTextModel,\n    t5_text_encoder: T5EncoderModel,\n    device: torch.device,\n    num_images_per_prompt: int = 1,\n    prompt_embeds: Optional[torch.FloatTensor] = None,\n    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n    lora_scale: Optional[float] = None,\n    use_peft_backend: bool = False,\n    clip_tokenizer_max_length: int = 77,\n    t5_tokenizer_max_length: 
int = 512,\n    logger: logging.Logger | None = None,\n) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:\n    \"\"\"\n    Encodes the prompt using both CLIP and T5 text encoders.\n\n    Returns:\n        Tuple containing:\n            - T5 text embeddings\n            - CLIP pooled embeddings\n            - Text IDs\n    \"\"\"\n    # Apply LoRA scale if needed\n    lora_applied = handle_lora_scale(\n        clip_text_encoder=clip_text_encoder,\n        t5_text_encoder=t5_text_encoder,\n        lora_scale=lora_scale,\n        use_peft_backend=use_peft_backend,\n    )\n\n    # If no pre-generated embeddings, create them\n    if prompt_embeds is None:\n        prompt_2 = prompt_2 or prompt\n        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n        # Get CLIP pooled embeddings\n        pooled_prompt_embeds = get_clip_prompt_embeds(\n            prompt=prompt,\n            tokenizer=clip_tokenizer,\n            text_encoder=clip_text_encoder,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            tokenizer_max_length=clip_tokenizer_max_length,\n        )\n\n        # Get T5 text embeddings\n        prompt_embeds = get_t5_prompt_embeds(\n            prompt=prompt_2,\n            tokenizer=t5_tokenizer,\n            text_encoder=t5_text_encoder,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            tokenizer_max_length=t5_tokenizer_max_length,\n        )\n\n    # Reset LoRA scale if it was applied\n    reset_lora_scale(\n        clip_text_encoder=clip_text_encoder,\n        t5_text_encoder=t5_text_encoder,\n        lora_scale=lora_scale,\n        lora_applied=lora_applied,\n        use_peft_backend=use_peft_backend,\n    )\n\n    # Create text_ids placeholder for model\n    dtype = clip_text_encoder.dtype if clip_text_encoder is not None else t5_text_encoder.dtype\n    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, 
dtype=dtype)\n\n    return prompt_embeds, pooled_prompt_embeds, text_ids\n"
  },
  {
    "path": "src/invoke_training/_shared/flux/lora_checkpoint_utils.py",
    "content": "# ruff: noqa: N806\nimport os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import FluxTransformer2DModel\nfrom transformers import CLIPTextModel\n\nfrom invoke_training._shared.checkpoints.lora_checkpoint_utils import (\n    _convert_peft_state_dict_to_kohya_state_dict,\n    load_multi_model_peft_checkpoint,\n    save_multi_model_peft_checkpoint,\n)\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\n\nFLUX_TRANSFORMER_TARGET_MODULES = [\n    # double blocks\n    \"attn.add_k_proj\",\n    \"attn.add_q_proj\",\n    \"attn.add_v_proj\",\n    \"attn.to_add_out\",\n    \"attn.to_k\",\n    \"attn.to_q\",\n    \"attn.to_v\",\n    \"attn.to_out.0\",\n    \"ff.net.0.proj\",\n    \"ff.net.2.0\",\n    \"ff_context.net.0.proj\",\n    \"ff_context.net.2.0\",\n    # single blocks\n    \"attn.to_k\",\n    \"attn.to_q\",\n    \"attn.to_v\",\n    \"proj_mlp\",\n    \"proj_out\",\n    \"proj_in\",\n]\n\nTEXT_ENCODER_TARGET_MODULES = [\"fc1\", \"fc2\", \"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n\n# Module lists copied from diffusers training script.\n# These module lists will produce lighter, less expressive, LoRA models than the non-light versions.\nFLUX_TRANSFORMER_TARGET_MODULES_LIGHT = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\nFLUX_TEXT_ENCODER_TARGET_MODULES_LIGHT = [\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n\nFLUX_PEFT_TRANSFORMER_KEY = \"transformer\"\nFLUX_PEFT_TEXT_ENCODER_1_KEY = \"text_encoder_1\"\nFLUX_PEFT_TEXT_ENCODER_2_KEY = \"text_encoder_2\"\n\nFLUX_KOHYA_TRANSFORMER_KEY = \"lora_unet\"\nFLUX_KOHYA_TEXT_ENCODER_1_KEY = \"lora_clip\"\nFLUX_KOHYA_TEXT_ENCODER_2_KEY = \"lora_t5\"\n\nFLUX_PEFT_TO_KOHYA_KEYS = {\n    FLUX_PEFT_TRANSFORMER_KEY: FLUX_KOHYA_TRANSFORMER_KEY,\n    FLUX_PEFT_TEXT_ENCODER_1_KEY: FLUX_KOHYA_TEXT_ENCODER_1_KEY,\n    FLUX_PEFT_TEXT_ENCODER_2_KEY: FLUX_KOHYA_TEXT_ENCODER_2_KEY,\n}\n\n\ndef save_flux_peft_checkpoint(\n    checkpoint_dir: Path | str,\n    
transformer: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n):\n    models = {}\n    if transformer is not None:\n        models[FLUX_PEFT_TRANSFORMER_KEY] = transformer\n    if text_encoder_1 is not None:\n        models[FLUX_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1\n    if text_encoder_2 is not None:\n        models[FLUX_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2\n\n    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)\n\n\ndef load_flux_peft_checkpoint(\n    checkpoint_dir: Path | str,\n    transformer: FluxTransformer2DModel,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    is_trainable: bool = False,\n):\n    models = load_multi_model_peft_checkpoint(\n        checkpoint_dir=checkpoint_dir,\n        models={\n            FLUX_PEFT_TRANSFORMER_KEY: transformer,\n            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,\n            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,\n        },\n        is_trainable=is_trainable,\n        raise_if_subdir_missing=False,\n    )\n\n    return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY]\n\n\ndef save_flux_kohya_checkpoint(\n    checkpoint_path: Path,\n    transformer: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n):\n    kohya_prefixes = []\n    models = []\n    for kohya_prefix, peft_model in zip(\n        [FLUX_KOHYA_TRANSFORMER_KEY, FLUX_KOHYA_TEXT_ENCODER_1_KEY], [transformer, text_encoder_1]\n    ):\n        if peft_model is not None:\n            kohya_prefixes.append(kohya_prefix)\n            models.append(peft_model)\n\n    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)\n\n    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n    save_state_dict(kohya_state_dict, checkpoint_path)\n\n\ndef 
convert_flux_peft_checkpoint_to_kohya_state_dict(\n    in_checkpoint_dir: Path,\n    out_checkpoint_file: Path,\n    dtype: torch.dtype = torch.float32,\n) -> dict[str, torch.Tensor]:\n    \"\"\"Convert Flux PEFT models to a Kohya-format LoRA state dict.\"\"\"\n    # Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model.\n    peft_model_dirs = os.listdir(in_checkpoint_dir)\n    peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs]  # Convert to Path objects.\n    peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()]  # Filter out non-directories.\n\n    if len(peft_model_dirs) == 0:\n        raise ValueError(f\"No checkpoint files found in directory '{in_checkpoint_dir}'.\")\n\n    kohya_state_dict = {}\n    for peft_model_dir in peft_model_dirs:\n        if peft_model_dir.name in FLUX_PEFT_TO_KOHYA_KEYS:\n            kohya_prefix = FLUX_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]\n        else:\n            raise ValueError(f\"Unrecognized checkpoint directory: '{peft_model_dir}'.\")\n\n        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:\n        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689\n        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).\n        # Also, I could see this interface breaking in the future.\n        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)\n        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device=\"cpu\")\n\n        kohya_state_dict.update(\n            _convert_peft_state_dict_to_kohya_state_dict(\n                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype\n            )\n        )\n\n    save_state_dict(kohya_state_dict, out_checkpoint_file)\n\n\ndef _convert_peft_models_to_kohya_state_dict(\n    kohya_prefixes: 
list[str], models: list[peft.PeftModel]\n) -> dict[str, torch.Tensor]:\n    kohya_state_dict = {}\n    default_adapter_name = \"default\"\n\n    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):\n        lora_config = peft_model.peft_config[default_adapter_name]\n        assert isinstance(lora_config, peft.LoraConfig)\n\n        state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name)\n\n        if kohya_prefix == FLUX_KOHYA_TRANSFORMER_KEY:\n            state_dict = convert_diffusers_to_flux_transformer_checkpoint(state_dict)\n\n        kohya_state_dict.update(\n            _convert_peft_state_dict_to_kohya_state_dict(\n                lora_config=lora_config,\n                peft_state_dict=state_dict,\n                prefix=kohya_prefix,\n                dtype=torch.float32,\n            )\n        )\n\n    return kohya_state_dict\n\n\ndef find_matching_key_prefix(state_dict, key_pattern):\n    \"\"\"\n    Find if any key in the state dictionary matches the given pattern.\n\n    Args:\n        state_dict: The state dictionary to search in\n        key_pattern: The pattern to look for in keys\n\n    Returns:\n        The matching prefix if found, False otherwise\n    \"\"\"\n    base_prefix = key_pattern.split(\".lora_A\")[0].split(\".lora_B\")[0].split(\".weight\")[0]\n\n    for key in state_dict.keys():\n        if base_prefix in key:\n            return base_prefix\n    return False\n\n\ndef convert_layer_weights(target_dict, source_dict, source_pattern, target_pattern):\n    \"\"\"\n    Convert weights from source pattern to target pattern if they exist.\n\n    Args:\n        target_dict: Dictionary to store converted weights\n        source_dict: Source dictionary containing weights\n        source_pattern: Original key pattern to search for\n        target_pattern: New key pattern to use\n\n\n    Returns:\n        Tuple of (updated target_dict, updated source_dict)\n    \"\"\"\n    if original_key 
:= find_matching_key_prefix(source_dict, source_pattern):\n        # Find all keys matching the pattern\n        keys_to_convert = [k for k in source_dict.keys() if original_key in k]\n\n        for key in keys_to_convert:\n            # Create replacement key\n            new_key = key.replace(original_key, target_pattern.replace(\".weight\", \"\"))\n            # Transfer and remove from original\n            target_dict[new_key] = source_dict.pop(key)\n\n    return target_dict, source_dict\n\n\ndef convert_double_transformer_block(target_dict, source_dict, prefix=\"\", block_idx=0):\n    \"\"\"\n    Convert weights for a double transformer block.\n\n    Args:\n        target_dict: Dictionary to store converted weights\n        source_dict: Source dictionary containing weights\n        prefix: Prefix for the keys in the state dictionary\n        block_idx: Block index\n\n    Returns:\n        Tuple of (updated target_dict, updated source_dict)\n    \"\"\"\n    block_prefix = f\"transformer_blocks.{block_idx}.\"\n\n    # Convert norms\n    target_dict, source_dict = convert_layer_weights(\n        target_dict,\n        source_dict,\n        f\"{prefix}{block_prefix}norm1.linear.weight\",\n        f\"double_blocks.{block_idx}.img_mod.lin.weight\",\n    )\n\n    target_dict, source_dict = convert_layer_weights(\n        target_dict,\n        source_dict,\n        f\"{prefix}{block_prefix}norm1_context.linear.weight\",\n        f\"double_blocks.{block_idx}.txt_mod.lin.weight\",\n    )\n\n    # Convert attention weights by concatenating Q, K, V\n    try:\n        # Sample attention weights\n        sample_q_A = source_dict.pop(f\"{prefix}{block_prefix}attn.to_q.lora_A.weight\")\n        sample_q_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_q.lora_B.weight\")\n        sample_k_A = source_dict.pop(f\"{prefix}{block_prefix}attn.to_k.lora_A.weight\")\n        sample_k_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_k.lora_B.weight\")\n        sample_v_A = 
source_dict.pop(f\"{prefix}{block_prefix}attn.to_v.lora_A.weight\")\n        sample_v_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_v.lora_B.weight\")\n\n        # Context attention weights\n        context_q_A = source_dict.pop(f\"{prefix}{block_prefix}attn.add_q_proj.lora_A.weight\")\n        context_q_B = source_dict.pop(f\"{prefix}{block_prefix}attn.add_q_proj.lora_B.weight\")\n        context_k_A = source_dict.pop(f\"{prefix}{block_prefix}attn.add_k_proj.lora_A.weight\")\n        context_k_B = source_dict.pop(f\"{prefix}{block_prefix}attn.add_k_proj.lora_B.weight\")\n        context_v_A = source_dict.pop(f\"{prefix}{block_prefix}attn.add_v_proj.lora_A.weight\")\n        context_v_B = source_dict.pop(f\"{prefix}{block_prefix}attn.add_v_proj.lora_B.weight\")\n\n        # Concatenate Q, K, V for image and text\n        target_dict[f\"double_blocks.{block_idx}.img_attn.qkv.lora_A.weight\"] = torch.cat(\n            [sample_q_A, sample_k_A, sample_v_A], dim=0\n        )\n        target_dict[f\"double_blocks.{block_idx}.img_attn.qkv.lora_B.weight\"] = torch.cat(\n            [sample_q_B, sample_k_B, sample_v_B], dim=0\n        )\n        target_dict[f\"double_blocks.{block_idx}.txt_attn.qkv.lora_A.weight\"] = torch.cat(\n            [context_q_A, context_k_A, context_v_A], dim=0\n        )\n        target_dict[f\"double_blocks.{block_idx}.txt_attn.qkv.lora_B.weight\"] = torch.cat(\n            [context_q_B, context_k_B, context_v_B], dim=0\n        )\n    except KeyError as e:\n        print(f\"Error processing attention weights for block {block_idx}: {e}\")\n        raise\n\n    # Convert QK norms\n    norm_keys = [\n        (f\"{prefix}{block_prefix}attn.norm_q.weight\", f\"double_blocks.{block_idx}.img_attn.norm.query_norm.scale\"),\n        (f\"{prefix}{block_prefix}attn.norm_k.weight\", f\"double_blocks.{block_idx}.img_attn.norm.key_norm.scale\"),\n        (\n            f\"{prefix}{block_prefix}attn.norm_added_q.weight\",\n            
f\"double_blocks.{block_idx}.txt_attn.norm.query_norm.scale\",\n        ),\n        (f\"{prefix}{block_prefix}attn.norm_added_k.weight\", f\"double_blocks.{block_idx}.txt_attn.norm.key_norm.scale\"),\n    ]\n\n    for src_key, target_key in norm_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    # Convert MLP weights\n    mlp_keys = [\n        (f\"{prefix}{block_prefix}ff.net.0.proj.weight\", f\"double_blocks.{block_idx}.img_mlp.0.weight\"),\n        (f\"{prefix}{block_prefix}ff.net.2.weight\", f\"double_blocks.{block_idx}.img_mlp.2.weight\"),\n        (f\"{prefix}{block_prefix}ff_context.net.0.proj.weight\", f\"double_blocks.{block_idx}.txt_mlp.0.weight\"),\n        (f\"{prefix}{block_prefix}ff_context.net.2.weight\", f\"double_blocks.{block_idx}.txt_mlp.2.weight\"),\n    ]\n\n    for src_key, target_key in mlp_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    # Convert output projections\n    output_keys = [\n        (f\"{prefix}{block_prefix}attn.to_out.0.weight\", f\"double_blocks.{block_idx}.img_attn.proj.weight\"),\n        (f\"{prefix}{block_prefix}attn.to_add_out.weight\", f\"double_blocks.{block_idx}.txt_attn.proj.weight\"),\n    ]\n\n    for src_key, target_key in output_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    return target_dict, source_dict\n\n\ndef convert_single_transformer_block(target_dict, source_dict, prefix, block_idx):\n    \"\"\"\n    Convert weights for a single transformer block.\n\n    Args:\n        target_dict: Dictionary to store converted weights\n        source_dict: Source dictionary containing weights\n        prefix: Prefix for the keys in the state dictionary\n        block_idx: Block index\n\n    Returns:\n        Tuple of (updated target_dict, updated source_dict)\n    \"\"\"\n    block_prefix = 
f\"single_transformer_blocks.{block_idx}.\"\n\n    # Convert norm\n    target_dict, source_dict = convert_layer_weights(\n        target_dict,\n        source_dict,\n        f\"{prefix}{block_prefix}norm.linear.weight\",\n        f\"single_blocks.{block_idx}.modulation.lin.weight\",\n    )\n\n    try:\n        # Convert Q, K, V, MLP by concatenating\n        q_A = source_dict.pop(f\"{prefix}{block_prefix}attn.to_q.lora_A.weight\")\n        q_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_q.lora_B.weight\")\n        k_A = source_dict.pop(f\"{prefix}{block_prefix}attn.to_k.lora_A.weight\")\n        k_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_k.lora_B.weight\")\n        v_A = source_dict.pop(f\"{prefix}{block_prefix}attn.to_v.lora_A.weight\")\n        v_B = source_dict.pop(f\"{prefix}{block_prefix}attn.to_v.lora_B.weight\")\n        mlp_A = source_dict.pop(f\"{prefix}{block_prefix}proj_mlp.lora_A.weight\")\n        mlp_B = source_dict.pop(f\"{prefix}{block_prefix}proj_mlp.lora_B.weight\")\n\n        target_dict[f\"single_blocks.{block_idx}.linear1.lora_A.weight\"] = torch.cat([q_A, k_A, v_A, mlp_A], dim=0)\n        target_dict[f\"single_blocks.{block_idx}.linear1.lora_B.weight\"] = torch.cat([q_B, k_B, v_B, mlp_B], dim=0)\n    except KeyError as e:\n        print(f\"Error processing attention weights for single block {block_idx}: {e}\")\n        raise\n\n    # Convert output projection\n    target_dict, source_dict = convert_layer_weights(\n        target_dict,\n        source_dict,\n        f\"{prefix}{block_prefix}proj_out.weight\",\n        f\"single_blocks.{block_idx}.linear2.weight\",\n    )\n\n    return target_dict, source_dict\n\n\ndef convert_embedding_layers(target_dict, source_dict, prefix, has_guidance=True):\n    \"\"\"\n    Convert time, text, guidance, and context embedding layers.\n\n    Args:\n        target_dict: Dictionary to store converted weights\n        source_dict: Source dictionary containing weights\n        prefix: Prefix 
for the keys in the state dictionary\n        has_guidance: Whether the model has guidance embedding\n\n    Returns:\n        Tuple of (updated target_dict, updated source_dict)\n    \"\"\"\n    # Convert time embedding\n    target_dict, source_dict = convert_layer_weights(\n        target_dict,\n        source_dict,\n        f\"{prefix}time_text_embed.timestep_embedder.linear_1.weight\",\n        \"time_in.in_layer.weight\",\n    )\n\n    # Convert text embedding\n    text_embed_keys = [\n        (f\"{prefix}time_text_embed.text_embedder.linear_1.weight\", \"vector_in.in_layer.weight\"),\n        (f\"{prefix}time_text_embed.text_embedder.linear_2.weight\", \"vector_in.out_layer.weight\"),\n    ]\n\n    for src_key, target_key in text_embed_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    # Convert guidance embedding if needed\n    if has_guidance:\n        guidance_keys = [\n            (f\"{prefix}time_text_embed.guidance_embedder.linear_1.weight\", \"guidance_in.in_layer.weight\"),\n            (f\"{prefix}time_text_embed.guidance_embedder.linear_2.weight\", \"guidance_in.out_layer.weight\"),\n        ]\n\n        for src_key, target_key in guidance_keys:\n            target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    # Convert context and image embedders\n    embed_keys = [\n        (f\"{prefix}context_embedder.weight\", \"txt_in.weight\"),\n        (f\"{prefix}x_embedder.weight\", \"img_in.weight\"),\n    ]\n\n    for src_key, target_key in embed_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    return target_dict, source_dict\n\n\ndef convert_output_layers(target_dict, source_dict, prefix):\n    \"\"\"\n    Convert final output layers.\n\n    Args:\n        target_dict: Dictionary to store converted weights\n        source_dict: Source dictionary containing weights\n        
prefix: Prefix for the keys in the state dictionary\n\n    Returns:\n        Tuple of (updated target_dict, updated source_dict)\n    \"\"\"\n    output_keys = [\n        (f\"{prefix}proj_out.weight\", \"final_layer.linear.weight\"),\n        (f\"{prefix}proj_out.bias\", \"final_layer.linear.bias\"),\n        (f\"{prefix}norm_out.linear.weight\", \"final_layer.adaLN_modulation.1.weight\"),\n    ]\n\n    for src_key, target_key in output_keys:\n        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)\n\n    return target_dict, source_dict\n\n\ndef convert_diffusers_to_flux_transformer_checkpoint(\n    diffusers_state_dict,\n    num_layers=19,\n    num_single_layers=38,\n    has_guidance=True,\n    old_prefix=\"base_model.model.\",\n    new_prefix=FLUX_KOHYA_TRANSFORMER_KEY,\n):\n    \"\"\"\n    Convert a diffusers state dictionary to flux transformer checkpoint format.\n\n    Args:\n        diffusers_state_dict: Source diffusers state dictionary\n        num_layers: Number of double transformer layers\n        num_single_layers: Number of single transformer layers\n        has_guidance: Whether the model has guidance embedding\n        old_prefix: Prefix of the keys in the source dictionary\n        new_prefix: Prefix prepended, with a dot separator, to every key in the returned dictionary\n\n    Returns:\n        A new state dictionary in flux transformer format\n    \"\"\"\n    # Create a new state dictionary\n    flux_state_dict = {}\n\n    # Convert embedding layers\n    flux_state_dict, diffusers_state_dict = convert_embedding_layers(\n        flux_state_dict, diffusers_state_dict, old_prefix, has_guidance\n    )\n\n    # Convert double transformer blocks\n    for i in range(num_layers):\n        flux_state_dict, diffusers_state_dict = convert_double_transformer_block(\n            flux_state_dict, diffusers_state_dict, old_prefix, i\n        )\n\n    # Convert single transformer blocks\n    for i in range(num_single_layers):\n        flux_state_dict, diffusers_state_dict = convert_single_transformer_block(\n            
flux_state_dict, diffusers_state_dict, old_prefix, i\n        )\n\n    # Convert output layers\n    flux_state_dict, diffusers_state_dict = convert_output_layers(flux_state_dict, diffusers_state_dict, old_prefix)\n\n    # Check for leftover keys\n    if diffusers_state_dict:\n        print(f\"Unexpected keys: {list(diffusers_state_dict.keys())}\")\n\n    # Replace the old prefix with the new prefix\n    keys = list(flux_state_dict.keys())\n    for key in keys:\n        new_key = f\"{new_prefix}.{key}\"\n        flux_state_dict[new_key] = flux_state_dict.pop(key)\n    return flux_state_dict\n"
  },
  {
    "path": "src/invoke_training/_shared/flux/model_loading_utils.py",
    "content": "import logging\nfrom enum import Enum\n\nimport torch\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer\n\n\nclass PipelineVersionEnum(Enum):\n    FLUX = \"FLUX\"\n\n\ndef load_pipeline(\n    logger: logging.Logger,\n    model_name_or_path: str = \"black-forest-labs/FLUX.1-dev\",\n    pipeline_version: PipelineVersionEnum = PipelineVersionEnum.FLUX,\n    transformer_path: str | None = None,\n    text_encoder_1_path: str | None = None,\n    text_encoder_2_path: str | None = None,\n    torch_dtype: torch.dtype | None = None,\n) -> FluxPipeline:\n    \"\"\"Load a Flux pipeline with optional custom components from .safetensors files.\n\n    Args:\n        logger: Logger instance\n        model_name_or_path: Base model path or repository\n        pipeline_version: Pipeline version (currently only FLUX supported)\n        transformer_path: Path to custom transformer .safetensors file\n        text_encoder_1_path: Path to custom CLIP text encoder .safetensors file\n        text_encoder_2_path: Path to custom T5 text encoder .safetensors file\n        torch_dtype: Desired dtype for the models\n    Returns:\n        FluxPipeline: Configured pipeline with custom components if specified\n    \"\"\"\n    if pipeline_version != PipelineVersionEnum.FLUX:\n        raise ValueError(f\"Invalid pipeline version: {pipeline_version}\")\n\n    # Prepare kwargs for from_pretrained\n    kwargs = {\"torch_dtype\": torch_dtype}\n\n    # Add components only if custom paths are provided\n    if transformer_path is not None:\n        # load_model_from_file_or_pretrained(FluxTransformer2DModel, transformer_path, torch_dtype=torch_dtype,\n        # use_safetensors=True, subfolder=\"transformer\")\n        kwargs[\"transformer\"] = FluxTransformer2DModel.from_pretrained(\n            transformer_path,\n            torch_dtype=torch_dtype,\n 
       )\n        logger.info(f\"Loading custom transformer from {transformer_path}\")\n\n    if text_encoder_1_path is not None:\n        logger.info(f\"Loading custom CLIP text encoder from {text_encoder_1_path}\")\n        kwargs[\"text_encoder\"] = CLIPTextModel.from_pretrained(text_encoder_1_path, torch_dtype=torch_dtype)\n\n    if text_encoder_2_path is not None:\n        logger.info(f\"Loading custom T5 text encoder from {text_encoder_2_path}\")\n        kwargs[\"text_encoder_2\"] = T5EncoderModel.from_pretrained(text_encoder_2_path, torch_dtype=torch_dtype)\n\n    # Load the pipeline with any custom components\n    pipeline = FluxPipeline.from_pretrained(model_name_or_path, **kwargs)\n\n    return pipeline\n\n\ndef load_models_flux(\n    logger: logging.Logger,\n    model_name_or_path: str = \"black-forest-labs/FLUX.1-dev\",\n    dtype: torch.dtype | None = None,\n    transformer_path: str | None = None,\n    text_encoder_1_path: str | None = None,\n    text_encoder_2_path: str | None = None,\n) -> tuple[CLIPTokenizer, FlowMatchEulerDiscreteScheduler, CLIPTextModel, AutoencoderKL, FluxTransformer2DModel]:\n    \"\"\"Load all models required for training from disk, freeze their weights, cast them to `dtype` (when\n    given) and put them in eval mode. The models are not moved to a device by this function.\n\n    Args:\n        logger: Logger instance\n        model_name_or_path: Base model path or repository\n        dtype: Desired dtype for the models\n        transformer_path: Path to custom transformer .safetensors file\n        text_encoder_1_path: Path to custom CLIP text encoder .safetensors file\n        text_encoder_2_path: Path to custom T5 text encoder .safetensors file\n\n    Returns:\n        The tuple (tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer).\n    \"\"\"\n\n    pipeline: FluxPipeline = load_pipeline(\n        logger=logger,\n        model_name_or_path=model_name_or_path,\n        pipeline_version=PipelineVersionEnum.FLUX,\n        transformer_path=transformer_path,\n        text_encoder_1_path=text_encoder_1_path,\n        text_encoder_2_path=text_encoder_2_path,\n        
torch_dtype=dtype,\n    )\n\n    # Tokenizers and text encoders.\n    tokenizer_1: CLIPTokenizer = pipeline.tokenizer\n    text_encoder_1: CLIPTextModel = pipeline.text_encoder\n\n    tokenizer_2: T5Tokenizer = pipeline.tokenizer_2\n    text_encoder_2: T5EncoderModel = pipeline.text_encoder_2\n\n    # Transformer and Scheduler\n    transformer: FluxTransformer2DModel = pipeline.transformer\n    noise_scheduler: FlowMatchEulerDiscreteScheduler = pipeline.scheduler\n\n    # Decoder\n    vae: AutoencoderKL = pipeline.vae\n\n    # Log component status\n    logger.info(\n        f\"Pipeline components loaded: tokenizer_1={tokenizer_1 is not None}, \"\n        f\"text_encoder_1={text_encoder_1 is not None}, \"\n        f\"tokenizer_2={tokenizer_2 is not None}, \"\n        f\"text_encoder_2={text_encoder_2 is not None}, \"\n        f\"transformer={transformer is not None}, \"\n        f\"vae={vae is not None}\"\n    )\n\n    # Check for None components\n    if text_encoder_1 is None:\n        raise ValueError(\n            \"text_encoder_1 failed to load. \"\n            \"Check if you have access to the model repository and are properly authenticated.\"\n        )\n    if text_encoder_2 is None:\n        raise ValueError(\n            \"text_encoder_2 failed to load. \"\n            \"Check if you have access to the model repository and are properly authenticated.\"\n        )\n    if transformer is None:\n        raise ValueError(\n            \"transformer failed to load. \"\n            \"Check if you have access to the model repository and are properly authenticated.\"\n        )\n    if vae is None:\n        raise ValueError(\n            \"vae failed to load. 
Check if you have access to the model repository and are properly authenticated.\"\n        )\n\n    # Disable gradient calculation for model weights to save memory.\n    text_encoder_1.requires_grad_(False)\n    text_encoder_2.requires_grad_(False)\n    vae.requires_grad_(False)\n    transformer.requires_grad_(False)\n\n    if dtype is not None:\n        text_encoder_1 = text_encoder_1.to(dtype=dtype)\n        text_encoder_2 = text_encoder_2.to(dtype=dtype)\n        vae = vae.to(dtype=dtype)\n        transformer = transformer.to(dtype=dtype)\n\n    # Put models in 'eval' mode.\n    text_encoder_1.eval()\n    text_encoder_2.eval()\n    vae.eval()\n    transformer.eval()\n\n    return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer\n"
  },
  {
    "path": "src/invoke_training/_shared/flux/validation.py",
    "content": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfrom accelerate.hooks import remove_hook_from_module\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom peft import PeftModel\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\nfrom invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\n\nNUM_INFERENCE_STEPS = 20\n\n\ndef generate_validation_images_flux(  # noqa: C901\n    epoch: int,\n    step: int,\n    out_dir: str,\n    accelerator: Accelerator,\n    vae: AutoencoderKL,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    tokenizer_1: CLIPTokenizer,\n    tokenizer_2: CLIPTokenizer,\n    noise_scheduler: FlowMatchEulerDiscreteScheduler,\n    transformer: FluxTransformer2DModel | PeftModel,\n    config: FluxLoraConfig,\n    logger: logging.Logger,\n    callbacks: list[PipelineCallbacks] | None = None,\n):\n    \"\"\"Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout\n    training.\n    \"\"\"\n    # Record original model devices so that we can restore this state after running the pipeline with CPU model\n    # offloading.\n    transformer_device = transformer.device\n    vae_device = vae.device\n    text_encoder_1_device = text_encoder_1.device\n    text_encoder_2_device = text_encoder_2.device\n\n    # Create pipeline.\n    pipeline = FluxPipeline(\n        vae=vae,\n        text_encoder=text_encoder_1,\n        text_encoder_2=text_encoder_2,\n        tokenizer=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        transformer=transformer,\n        scheduler=noise_scheduler,\n    )\n    if 
config.enable_cpu_offload_during_validation:\n        pipeline.enable_model_cpu_offload(accelerator.device.index or 0)\n    else:\n        pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    validation_resolution = Resolution.parse(config.data_loader.resolution)\n\n    validation_images = ValidationImages(images=[], epoch=epoch, step=step)\n\n    validation_step_dir = os.path.join(out_dir, \"validation\", f\"epoch_{epoch:0>8}-step_{step:0>8}\")\n    logger.info(f\"Generating validation images ({validation_step_dir}).\")\n\n    # Run inference.\n    with torch.no_grad():\n        for prompt_idx in range(len(config.validation_prompts)):\n            positive_prompt = config.validation_prompts[prompt_idx]\n            negative_prompt = None\n            logger.info(f\"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'\")\n\n            generator = torch.Generator(device=accelerator.device)\n            if config.seed is not None:\n                generator = generator.manual_seed(config.seed)\n\n            images = []\n            for _ in range(config.num_validation_images_per_prompt):\n                with accelerator.autocast():\n                    images.append(\n                        pipeline(\n                            positive_prompt,\n                            num_inference_steps=NUM_INFERENCE_STEPS,\n                            generator=generator,\n                            height=validation_resolution.height,\n                            width=validation_resolution.width,\n                            negative_prompt=negative_prompt,\n                        ).images[0]\n                    )\n\n            # Save images to disk.\n            validation_prompt_dir = os.path.join(validation_step_dir, f\"prompt_{prompt_idx:0>4}\")\n            os.makedirs(validation_prompt_dir)\n            for image_idx, image in enumerate(images):\n                image_path = 
os.path.join(validation_prompt_dir, f\"{image_idx:0>4}.jpg\")\n                validation_images.images.append(\n                    ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)\n                )\n                image.save(image_path)\n\n            # Log images to trackers. Currently, only tensorboard is supported.\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\n                        f\"validation (prompt {prompt_idx})\",\n                        np_images,\n                        step,\n                        dataformats=\"NHWC\",\n                    )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    for model in [transformer, vae, text_encoder_1, text_encoder_2]:\n        remove_hook_from_module(model)\n\n    # Restore models to original devices.\n    transformer.to(transformer_device)\n    vae.to(vae_device)\n    text_encoder_1.to(text_encoder_1_device)\n    text_encoder_2.to(text_encoder_2_device)\n\n    # Run callbacks.\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_validation_images(images=validation_images)\n"
  },
  {
    "path": "src/invoke_training/_shared/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/optimizer/optimizer_utils.py",
    "content": "import torch\nfrom prodigyopt import Prodigy\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\ndef initialize_optimizer(\n    config: AdamOptimizerConfig | ProdigyOptimizerConfig, trainable_params: list\n) -> torch.optim.Optimizer:\n    \"\"\"Initialize an optimizer based on the provided config.\"\"\"\n\n    if config.optimizer_type == \"AdamW\":\n        adam_cls = torch.optim.AdamW\n        if config.use_8bit:\n            try:\n                import bitsandbytes  # noqa: F401\n            except ImportError:\n                raise ImportError(\n                    \"bitsandbytes is not installed. bitsandbytes is required to use the 8-bit Adam optimizer. Install \"\n                    'it by running `pip install \".[bitsandbytes]\"`.'\n                )\n            adam_cls = bitsandbytes.optim.AdamW8bit\n        optimizer = adam_cls(\n            trainable_params,\n            lr=config.learning_rate,\n            betas=(config.beta1, config.beta2),\n            weight_decay=config.weight_decay,\n            eps=config.epsilon,\n        )\n    elif config.optimizer_type == \"Prodigy\":\n        optimizer = Prodigy(\n            trainable_params,\n            lr=config.learning_rate,\n            weight_decay=config.weight_decay,\n            use_bias_correction=config.use_bias_correction,\n            safeguard_warmup=config.safeguard_warmup,\n        )\n    else:\n        raise ValueError(f\"'{config.optimizer_type}' is not a supported optimizer.\")\n\n    return optimizer\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/base_model_version.py",
    "content": "from enum import Enum\n\nfrom transformers import PretrainedConfig\n\n\nclass BaseModelVersionEnum(Enum):\n    STABLE_DIFFUSION_V1 = 1\n    STABLE_DIFFUSION_V2 = 2\n    STABLE_DIFFUSION_SDXL_BASE = 3\n    STABLE_DIFFUSION_SDXL_REFINER = 4\n\n\ndef get_base_model_version(\n    diffusers_model_name: str, revision: str = \"main\", local_files_only: bool = True\n) -> BaseModelVersionEnum:\n    \"\"\"Returns the `BaseModelVersionEnum` of a diffusers model.\n\n    Args:\n        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub).\n        revision (str, optional): The model revision (branch or commit hash). Defaults to \"main\".\n\n    Raises:\n        Exception: If the base model version can not be determined.\n\n    Returns:\n        BaseModelVersionEnum: The detected base model version.\n    \"\"\"\n    unet_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path=diffusers_model_name,\n        revision=revision,\n        subfolder=\"unet\",\n        local_files_only=local_files_only,\n    )\n\n    # This logic was copied from\n    # https://github.com/invoke-ai/InvokeAI/blob/e77400ab62d24acbdf2f48a7427705e7b8b97e4a/invokeai/backend/model_management/model_probe.py#L412-L421\n    # This seems fragile. If you see this and know of a better way to detect the base model version, your contribution\n    # would be welcome.\n    if unet_config.cross_attention_dim == 768:\n        return BaseModelVersionEnum.STABLE_DIFFUSION_V1\n    elif unet_config.cross_attention_dim == 1024:\n        return BaseModelVersionEnum.STABLE_DIFFUSION_V2\n    elif unet_config.cross_attention_dim == 1280:\n        return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER\n    elif unet_config.cross_attention_dim == 2048:\n        return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE\n    else:\n        raise Exception(\n            \"Failed to determine base model version. 
UNet cross_attention_dim has unexpected value: \"\n            f\"'{unet_config.cross_attention_dim}'.\"\n        )\n\n\ndef check_base_model_version(\n    allowed_versions: set[BaseModelVersionEnum],\n    diffusers_model_name: str,\n    revision: str = \"main\",\n    local_files_only: bool = True,\n):\n    \"\"\"Helper function that checks if a diffusers model is compatible with a set of base model versions.\n\n    Args:\n        allowed_versions (set[BaseModelVersionEnum]): The set of allowed base model versions.\n        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub) to check.\n        revision (str, optional): The model revision (branch or commit hash). Defaults to \"main\".\n\n    Raises:\n        ValueError: If the model has an unsupported version.\n    \"\"\"\n    version = get_base_model_version(diffusers_model_name, revision, local_files_only)\n    if version not in allowed_versions:\n        raise ValueError(\n            f\"Model '{diffusers_model_name}' (revision='{revision}') has an unsupported version: '{version.name}'. \"\n            f\"Supported versions: {[v.name for v in allowed_versions]}.\"\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py",
    "content": "from pathlib import Path\n\nimport torch\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\n\ndef save_sdxl_diffusers_unet_checkpoint(\n    checkpoint_path: Path | str, unet: UNet2DConditionModel, save_dtype: torch.dtype\n):\n    # Record original device and dtype so that we can restore it afterward.\n    model_list = [unet]\n    original_devices = [model.device for model in model_list]\n    original_dtypes = [model.dtype for model in model_list]\n\n    # Save UNet.\n    unet.to(dtype=save_dtype)\n    unet.save_pretrained(Path(checkpoint_path) / \"unet\")\n\n    # Restore original device/dtype.\n    for model, device, dtype in zip(model_list, original_devices, original_dtypes, strict=True):\n        model.to(device=device, dtype=dtype)\n\n\ndef save_sdxl_diffusers_checkpoint(\n    checkpoint_path: Path | str,\n    vae: AutoencoderKL,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    tokenizer_1: CLIPTokenizer,\n    tokenizer_2: CLIPTokenizer,\n    noise_scheduler: DDPMScheduler,\n    unet: UNet2DConditionModel,\n    save_dtype: torch.dtype,\n):\n    # Record original device and dtype so that we can restore it afterward.\n    # TODO(ryand): This method of restoring original device/dtype is a bit naive. It does not handle mixed precisions\n    # within a model, and results in a loss of precision if the save_dtype is lower precision than the model dtype. 
We\n    # may need to revisit this.\n    model_list = [vae, text_encoder_1, text_encoder_2, unet]\n    original_devices = [model.device for model in model_list]\n    original_dtypes = [model.dtype for model in model_list]\n\n    # Create pipeline.\n    pipeline = StableDiffusionXLPipeline(\n        vae=vae,\n        text_encoder=text_encoder_1,\n        text_encoder_2=text_encoder_2,\n        tokenizer=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        unet=unet,\n        scheduler=noise_scheduler,\n    )\n    pipeline = pipeline.to(dtype=save_dtype)\n\n    # Save pipeline.\n    pipeline.save_pretrained(checkpoint_path)\n\n    # Restore original device/dtype.\n    for model, device, dtype in zip(model_list, original_devices, original_dtypes, strict=True):\n        model.to(device=device, dtype=dtype)\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py",
    "content": "import os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import UNet2DConditionModel\nfrom transformers import CLIPTextModel\n\nfrom invoke_training._shared.checkpoints.lora_checkpoint_utils import (\n    _convert_peft_models_to_kohya_state_dict,\n    _convert_peft_state_dict_to_kohya_state_dict,\n    load_multi_model_peft_checkpoint,\n    save_multi_model_peft_checkpoint,\n)\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\n\n# Copied from https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87\nUNET_TARGET_MODULES = [\n    \"to_q\",\n    \"to_k\",\n    \"to_v\",\n    \"proj\",\n    \"proj_in\",\n    \"proj_out\",\n    \"conv\",\n    \"conv1\",\n    \"conv2\",\n    \"conv_shortcut\",\n    \"to_out.0\",\n    \"time_emb_proj\",\n    \"ff.net.2\",\n]\nTEXT_ENCODER_TARGET_MODULES = [\"fc1\", \"fc2\", \"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n\n# Module lists copied from diffusers training script.\n# These module lists will produce lighter, less expressive, LoRA models than the non-light versions.\nUNET_TARGET_MODULES_LIGHT = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\nTEXT_ENCODER_TARGET_MODULES_LIGHT = [\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n\nSD_PEFT_UNET_KEY = \"unet\"\nSD_PEFT_TEXT_ENCODER_KEY = \"text_encoder\"\n\nSDXL_PEFT_UNET_KEY = \"unet\"\nSDXL_PEFT_TEXT_ENCODER_1_KEY = \"text_encoder_1\"\nSDXL_PEFT_TEXT_ENCODER_2_KEY = \"text_encoder_2\"\n\nSD_KOHYA_UNET_KEY = \"lora_unet\"\nSD_KOHYA_TEXT_ENCODER_KEY = \"lora_te\"\n\nSDXL_KOHYA_UNET_KEY = \"lora_unet\"\nSDXL_KOHYA_TEXT_ENCODER_1_KEY = \"lora_te1\"\nSDXL_KOHYA_TEXT_ENCODER_2_KEY = \"lora_te2\"\n\nSD_PEFT_TO_KOHYA_KEYS = {\n    SD_PEFT_UNET_KEY: SD_KOHYA_UNET_KEY,\n    SD_PEFT_TEXT_ENCODER_KEY: SD_KOHYA_TEXT_ENCODER_KEY,\n}\n\nSDXL_PEFT_TO_KOHYA_KEYS = {\n    SDXL_PEFT_UNET_KEY: SDXL_KOHYA_UNET_KEY,\n    SDXL_PEFT_TEXT_ENCODER_1_KEY: 
SDXL_KOHYA_TEXT_ENCODER_1_KEY,\n    SDXL_PEFT_TEXT_ENCODER_2_KEY: SDXL_KOHYA_TEXT_ENCODER_2_KEY,\n}\n\n\ndef save_sd_peft_checkpoint(\n    checkpoint_dir: Path | str, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None\n):\n    models = {}\n    if unet is not None:\n        models[SD_PEFT_UNET_KEY] = unet\n    if text_encoder is not None:\n        models[SD_PEFT_TEXT_ENCODER_KEY] = text_encoder\n\n    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)\n\n\ndef load_sd_peft_checkpoint(\n    checkpoint_dir: Path | str, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, is_trainable: bool = False\n):\n    models = load_multi_model_peft_checkpoint(\n        checkpoint_dir=checkpoint_dir,\n        models={SD_PEFT_UNET_KEY: unet, SD_PEFT_TEXT_ENCODER_KEY: text_encoder},\n        is_trainable=is_trainable,\n        raise_if_subdir_missing=False,\n    )\n\n    return models[SD_PEFT_UNET_KEY], models[SD_PEFT_TEXT_ENCODER_KEY]\n\n\ndef save_sdxl_peft_checkpoint(\n    checkpoint_dir: Path | str,\n    unet: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n):\n    models = {}\n    if unet is not None:\n        models[SDXL_PEFT_UNET_KEY] = unet\n    if text_encoder_1 is not None:\n        models[SDXL_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1\n    if text_encoder_2 is not None:\n        models[SDXL_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2\n\n    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)\n\n\ndef load_sdxl_peft_checkpoint(\n    checkpoint_dir: Path | str,\n    unet: UNet2DConditionModel,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    is_trainable: bool = False,\n):\n    models = load_multi_model_peft_checkpoint(\n        checkpoint_dir=checkpoint_dir,\n        models={\n            SDXL_PEFT_UNET_KEY: unet,\n            SDXL_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,\n            
SDXL_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,\n        },\n        is_trainable=is_trainable,\n        raise_if_subdir_missing=False,\n    )\n\n    return models[SDXL_PEFT_UNET_KEY], models[SDXL_PEFT_TEXT_ENCODER_1_KEY], models[SDXL_PEFT_TEXT_ENCODER_2_KEY]\n\n\ndef save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None):\n    kohya_prefixes = []\n    models = []\n    for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]):\n        if peft_model is not None:\n            kohya_prefixes.append(kohya_prefix)\n            models.append(peft_model)\n\n    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)\n\n    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n    save_state_dict(kohya_state_dict, checkpoint_path)\n\n\ndef save_sdxl_kohya_checkpoint(\n    checkpoint_path: Path,\n    unet: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n):\n    kohya_prefixes = []\n    models = []\n    for kohya_prefix, peft_model in zip(\n        [SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY],\n        [unet, text_encoder_1, text_encoder_2],\n    ):\n        if peft_model is not None:\n            kohya_prefixes.append(kohya_prefix)\n            models.append(peft_model)\n\n    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)\n\n    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)\n    save_state_dict(kohya_state_dict, checkpoint_path)\n\n\ndef convert_sd_peft_checkpoint_to_kohya_state_dict(\n    in_checkpoint_dir: Path,\n    out_checkpoint_file: Path,\n    dtype: torch.dtype = torch.float32,\n) -> dict[str, torch.Tensor]:\n    \"\"\"Convert SD v1 or SDXL PEFT models to a Kohya-format LoRA state dict.\"\"\"\n    # Get the immediate subdirectories of 
the checkpoint directory. We assume that each subdirectory is a PEFT model.\n    peft_model_dirs = os.listdir(in_checkpoint_dir)\n    peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs]  # Convert to Path objects.\n    peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()]  # Filter out non-directories.\n\n    if len(peft_model_dirs) == 0:\n        raise ValueError(f\"No checkpoint files found in directory '{in_checkpoint_dir}'.\")\n\n    kohya_state_dict = {}\n    for peft_model_dir in peft_model_dirs:\n        if peft_model_dir.name in SD_PEFT_TO_KOHYA_KEYS:\n            kohya_prefix = SD_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]\n        elif peft_model_dir.name in SDXL_PEFT_TO_KOHYA_KEYS:\n            kohya_prefix = SDXL_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]\n        else:\n            raise ValueError(f\"Unrecognized checkpoint directory: '{peft_model_dir}'.\")\n\n        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:\n        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689\n        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).\n        # Also, I could see this interface breaking in the future.\n        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)\n        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device=\"cpu\")\n\n        kohya_state_dict.update(\n            _convert_peft_state_dict_to_kohya_state_dict(\n                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype\n            )\n        )\n\n    save_state_dict(kohya_state_dict, out_checkpoint_file)\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py",
    "content": "import torch\nfrom diffusers import DDPMScheduler\n\n\ndef compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor):\n    \"\"\"\n    Computes SNR.\n\n    Adapted from:\n    https://github.com/huggingface/diffusers/blob/ea9dc3fa90c70c7cd825ca2346a31153e08b5367/src/diffusers/training_utils.py#L40\n\n    Which was originally based on:\n    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849\n    \"\"\"\n    alphas_cumprod = noise_scheduler.alphas_cumprod\n    sqrt_alphas_cumprod = alphas_cumprod**0.5\n    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5\n\n    # Expand the tensors.\n    # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026\n    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):\n        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]\n    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)\n\n    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()\n    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):\n        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]\n    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)\n\n    # Compute SNR.\n    snr = (alpha / sigma) ** 2\n    return snr\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/model_loading_utils.py",
    "content": "import logging\nimport os\nimport typing\nfrom enum import Enum\n\nimport torch\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionPipeline,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.checkpoints.serialization import load_state_dict\n\nHF_VARIANT_FALLBACKS = [None, \"fp16\"]\n\n\nclass PipelineVersionEnum(Enum):\n    SD = \"SD\"\n    SDXL = \"SDXL\"\n\n\ndef load_pipeline(\n    logger: logging.Logger,\n    model_name_or_path: str,\n    pipeline_version: PipelineVersionEnum,\n    torch_dtype: torch.dtype = None,\n    variant: str | None = None,\n) -> typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]:\n    \"\"\"Load a Stable Diffusion pipeline from disk.\n\n    Args:\n        model_name_or_path (str): The name or path of the pipeline to load. Can be in diffusers format, or a single\n            stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5',\n            'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )\n        pipeline_version (PipelineVersionEnum): The pipeline version.\n        variant (str | None): The Hugging Face Hub variant. 
Only applies if `model_name_or_path` is a HF Hub model name.\n\n    Returns:\n        typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]: The loaded pipeline.\n    \"\"\"\n    if pipeline_version == PipelineVersionEnum.SD:\n        pipeline_class = StableDiffusionPipeline\n    elif pipeline_version == PipelineVersionEnum.SDXL:\n        pipeline_class = StableDiffusionXLPipeline\n    else:\n        raise ValueError(f\"Unsupported pipeline_version: '{pipeline_version}'.\")\n\n    if os.path.isfile(model_name_or_path):\n        return pipeline_class.from_single_file(\n            model_name_or_path,\n            torch_dtype=torch_dtype,\n            safety_checker=None,\n            feature_extractor=None,\n        )\n\n    return from_pretrained_with_variant_fallback(\n        logger=logger,\n        model_class=pipeline_class,\n        model_name_or_path=model_name_or_path,\n        torch_dtype=torch_dtype,\n        variant=variant,\n        # kwargs\n        safety_checker=None,\n        requires_safety_checker=False,\n    )\n\n\nModelT = typing.TypeVar(\"ModelT\")\n\n\ndef from_pretrained_with_variant_fallback(\n    logger: logging.Logger,\n    model_class: typing.Type[ModelT],\n    model_name_or_path: str,\n    torch_dtype: torch.dtype | None = None,\n    variant: str | None = None,\n    **kwargs,\n) -> ModelT:\n    \"\"\"A wrapper for .from_pretrained() that tries multiple variants if the initial one fails.\"\"\"\n    variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]\n\n    model: ModelT | None = None\n    for variant_to_try in variants_to_try:\n        if variant_to_try != variant:\n            logger.warning(f\"Trying fallback variant '{variant_to_try}'.\")\n        try:\n            model = model_class.from_pretrained(\n                model_name_or_path,\n                torch_dtype=torch_dtype,\n                variant=variant_to_try,\n                **kwargs,\n            )\n        except (OSError, ValueError) 
as e:\n            error_str = str(e)\n            if \"no file named\" in error_str or \"no such modeling files are available\" in error_str:\n                # Ok; we'll try the variant fallbacks.\n                logger.warning(f\"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}.\")\n            else:\n                raise\n\n        if model is not None:\n            break\n\n    if model is None:\n        raise RuntimeError(f\"Failed to load model '{model_name_or_path}'.\")\n    return model\n\n\ndef load_models_sd(\n    logger: logging.Logger,\n    model_name_or_path: str,\n    hf_variant: str | None = None,\n    base_embeddings: dict[str, str] = None,\n    dtype: torch.dtype | None = None,\n) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]:\n    \"\"\"Load all models required for training from disk, transfer them to the\n    target training device and cast their weight dtypes.\n    \"\"\"\n    base_embeddings = base_embeddings or {}\n\n    pipeline: StableDiffusionPipeline = load_pipeline(\n        logger=logger,\n        model_name_or_path=model_name_or_path,\n        pipeline_version=PipelineVersionEnum.SD,\n        variant=hf_variant,\n    )\n\n    for token, embedding_path in base_embeddings.items():\n        pipeline.load_textual_inversion(embedding_path, token=token)\n\n    # Extract sub-models from the pipeline.\n    tokenizer: CLIPTokenizer = pipeline.tokenizer\n    text_encoder: CLIPTextModel = pipeline.text_encoder\n    vae: AutoencoderKL = pipeline.vae\n    unet: UNet2DConditionModel = pipeline.unet\n    noise_scheduler = DDPMScheduler(\n        beta_start=0.00085,\n        beta_end=0.012,\n        beta_schedule=\"scaled_linear\",\n        num_train_timesteps=1000,\n        clip_sample=False,\n        steps_offset=1,\n    )\n\n    # Disable gradient calculation for model weights to save memory.\n    text_encoder.requires_grad_(False)\n    vae.requires_grad_(False)\n    
unet.requires_grad_(False)\n\n    if dtype is not None:\n        text_encoder = text_encoder.to(dtype=dtype)\n        vae = vae.to(dtype=dtype)\n        unet = unet.to(dtype=dtype)\n\n    # Put models in 'eval' mode.\n    text_encoder.eval()\n    vae.eval()\n    unet.eval()\n\n    return tokenizer, noise_scheduler, text_encoder, vae, unet\n\n\ndef load_models_sdxl(\n    logger: logging.Logger,\n    model_name_or_path: str,\n    hf_variant: str | None = None,\n    vae_model: str | None = None,\n    base_embeddings: dict[str, str] = None,\n    dtype: torch.dtype | None = None,\n) -> tuple[\n    CLIPTokenizer,\n    CLIPTokenizer,\n    DDPMScheduler,\n    CLIPTextModel,\n    CLIPTextModel,\n    AutoencoderKL,\n    UNet2DConditionModel,\n]:\n    \"\"\"Load all models required for training, transfer them to the target training device and cast their weight\n    dtypes.\n    \"\"\"\n    base_embeddings = base_embeddings or {}\n\n    pipeline: StableDiffusionXLPipeline = load_pipeline(\n        logger=logger,\n        model_name_or_path=model_name_or_path,\n        pipeline_version=PipelineVersionEnum.SDXL,\n        variant=hf_variant,\n    )\n\n    for token, embedding_path in base_embeddings.items():\n        state_dict = load_state_dict(embedding_path)\n        pipeline.load_textual_inversion(\n            state_dict[\"clip_l\"],\n            token=token,\n            text_encoder=pipeline.text_encoder,\n            tokenizer=pipeline.tokenizer,\n        )\n        pipeline.load_textual_inversion(\n            state_dict[\"clip_g\"],\n            token=token,\n            text_encoder=pipeline.text_encoder_2,\n            tokenizer=pipeline.tokenizer_2,\n        )\n\n    # Extract sub-models from the pipeline.\n    tokenizer_1: CLIPTokenizer = pipeline.tokenizer\n    tokenizer_2: CLIPTokenizer = pipeline.tokenizer_2\n    text_encoder_1: CLIPTextModel = pipeline.text_encoder\n    text_encoder_2: CLIPTextModel = pipeline.text_encoder_2\n    vae: AutoencoderKL = 
pipeline.vae\n    unet: UNet2DConditionModel = pipeline.unet\n    noise_scheduler = DDPMScheduler(\n        beta_start=0.00085,\n        beta_end=0.012,\n        beta_schedule=\"scaled_linear\",\n        num_train_timesteps=1000,\n        clip_sample=False,\n        steps_offset=1,\n    )\n\n    if vae_model is not None:\n        vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model)\n\n    # Disable gradient calculation for model weights to save memory.\n    text_encoder_1.requires_grad_(False)\n    text_encoder_2.requires_grad_(False)\n    vae.requires_grad_(False)\n    unet.requires_grad_(False)\n\n    if dtype is not None:\n        text_encoder_1 = text_encoder_1.to(dtype=dtype)\n        text_encoder_2 = text_encoder_2.to(dtype=dtype)\n        vae = vae.to(dtype=dtype)\n        unet = unet.to(dtype=dtype)\n\n    # Put models in 'eval' mode.\n    text_encoder_1.eval()\n    text_encoder_2.eval()\n    vae.eval()\n    unet.eval()\n\n    return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/textual_inversion.py",
    "content": "import logging\n\nimport torch\nfrom accelerate import Accelerator\nfrom transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer\n\nfrom invoke_training._shared.checkpoints.serialization import load_state_dict\n\n\ndef _expand_placeholder_token(placeholder_token: str, num_vectors: int = 1) -> list[str]:\n    \"\"\"Expand a placeholder token into a list of placeholder tokens based on the number of embedding vectors being\n    trained.\n    \"\"\"\n    placeholder_tokens = [placeholder_token]\n    if num_vectors < 1:\n        raise ValueError(f\"num_vectors must be >1, but is '{num_vectors}'.\")\n    # Add dummy placeholder tokens if num_vectors > 1.\n    for i in range(1, num_vectors):\n        placeholder_tokens.append(f\"{placeholder_token}_{i}\")\n    return placeholder_tokens\n\n\ndef _add_tokens_to_tokenizer(placeholder_tokens: list[str], tokenizer: PreTrainedTokenizer):\n    \"\"\"Add new tokens to a tokenizer.\n\n    Raises:\n        ValueError: Raises if the tokenizer already contains one of the tokens in `placeholder_tokens`.\n    \"\"\"\n    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)\n    if num_added_tokens != len(placeholder_tokens):\n        raise ValueError(\n            f\"The tokenizer already contains one of the tokens in '{placeholder_tokens}'. Please pass a different\"\n            \" 'placeholder_token' that is not already in the tokenizer.\"\n        )\n\n\ndef expand_placeholders_in_caption(caption: str, tokenizer: CLIPTokenizer) -> str:\n    \"\"\"Expand any multi-vector placeholder tokens in the caption.\n\n    For example, \"a dog in the style of my_placeholder\", could get expanded to \"a dog in the style of my_placeholder\n    my_placeholder_1 my_placeholder_2\".\n\n    This implementation is based on\n    https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/textual_inversion.py#L144. 
This logic gets\n    applied automatically when running a full diffusers text-to-image pipeline.\n    \"\"\"\n    tokens = tokenizer.tokenize(caption)\n    unique_tokens = set(tokens)\n    for token in unique_tokens:\n        if token in tokenizer.added_tokens_encoder:\n            replacement = token\n            i = 1\n            while f\"{token}_{i}\" in tokenizer.added_tokens_encoder:\n                replacement += f\" {token}_{i}\"\n                i += 1\n\n            if replacement != token:\n                # If the replacement is different from the original token, then we double check that the replacement\n                # isn't already in the caption. If the replacement is already in the caption, this probably means that\n                # someone didn't realize that placeholder expansion is handled here.\n                assert replacement not in caption\n\n            caption = caption.replace(token, replacement)\n\n    return caption\n\n\ndef initialize_placeholder_tokens_from_initializer_token(\n    tokenizer: CLIPTokenizer,\n    text_encoder: CLIPTextModel,\n    initializer_token: str,\n    placeholder_token: str,\n    num_vectors: int,\n    logger: logging.Logger,\n) -> tuple[list[str], list[int]]:\n    # Convert the initializer_token to a token id.\n    initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)\n    if len(initializer_token_ids) > 1:\n        logger.warning(\n            f\"The initializer_token '{initializer_token}' gets tokenized to {len(initializer_token_ids)} tokens. \"\n            \"Only the first token will be used. 
It is recommended to choose a different initializer_token that maps to \"\n            \"a single token.\"\n        )\n\n    initializer_token_id = initializer_token_ids[0]\n\n    # Expand the tokenizer / text_encoder to include the placeholder tokens.\n    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors)\n    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)\n    # Resize the token embeddings as we have added new special tokens to the tokenizer.\n    text_encoder.resize_token_embeddings(len(tokenizer))\n    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a\n    # list.\n    assert isinstance(placeholder_token_ids, list)\n\n    # Initialize the newly-added placeholder token(s) with the embeddings of the initializer token.\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for placeholder_token_id in placeholder_token_ids:\n            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id].clone()\n\n    return placeholder_tokens, placeholder_token_ids\n\n\ndef initialize_placeholder_tokens_from_initial_phrase(\n    tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, initial_phrase: str, placeholder_token: str\n) -> tuple[list[str], list[int]]:\n    # Convert the initial_phrase to token ids.\n    initial_token_ids = tokenizer.encode(initial_phrase, add_special_tokens=False)\n\n    # Expand the tokenizer / text_encoder to include one placeholder token for each token in the initial_phrase.\n    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=len(initial_token_ids))\n    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)\n    # Resize the token embeddings as we have added new special tokens to the tokenizer.\n    text_encoder.resize_token_embeddings(len(tokenizer))\n    
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a\n    # list.\n    assert isinstance(placeholder_token_ids, list)\n\n    # Initialize the newly-added placeholder token(s) with the embeddings of the initial phrase.\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for placeholder_token_id, initial_token_id in zip(placeholder_token_ids, initial_token_ids):\n            token_embeds[placeholder_token_id] = token_embeds[initial_token_id].clone()\n\n    return placeholder_tokens, placeholder_token_ids\n\n\ndef initialize_placeholder_tokens_from_initial_embedding(\n    tokenizer: CLIPTokenizer,\n    text_encoder: CLIPTextModel,\n    initial_embedding_file: str,\n    placeholder_token: str,\n    num_vectors: int,\n) -> tuple[list[str], list[int]]:\n    # Expand the tokenizer / text_encoder to include the placeholder tokens.\n    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors)\n    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)\n    # Resize the token embeddings as we have added new special tokens to the tokenizer.\n    text_encoder.resize_token_embeddings(len(tokenizer))\n\n    state_dict = load_state_dict(initial_embedding_file)\n    if placeholder_token not in state_dict:\n        raise ValueError(\n            f\"The initial embedding at '{initial_embedding_file}' does not contain an embedding for placeholder token \"\n            f\"'{placeholder_token}'.\"\n        )\n\n    embeddings = state_dict[placeholder_token]\n    if embeddings.shape[0] != len(placeholder_tokens):\n        raise ValueError(\n            f\"The number of initial embeddings in '{initial_embedding_file}' ({embeddings.shape[0]}) does not match \"\n            f\"the expected number of placeholder tokens ({len(placeholder_tokens)}).\"\n        )\n\n    
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)\n    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a\n    # list.\n    assert isinstance(placeholder_token_ids, list)\n\n    # Initialize the newly-added placeholder token(s) with the loaded embeddings.\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for i, token_id in enumerate(placeholder_token_ids):\n            token_embeds[token_id] = embeddings[i].clone()\n\n    return placeholder_tokens, placeholder_token_ids\n\n\ndef restore_original_embeddings(\n    tokenizer: CLIPTokenizer,\n    placeholder_token_ids: list[int],\n    accelerator: Accelerator,\n    text_encoder: CLIPTextModel,\n    orig_embeds_params: torch.Tensor,\n):\n    \"\"\"Restore the text_encoder embeddings that we are not actively training to make sure they don't change.\n\n    TODO(ryand): Look into whether this is actually necessary if we set requires_grad correctly.\n    \"\"\"\n    index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)\n    index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False\n    index_updates = ~index_no_updates\n    with torch.no_grad():\n        unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)\n        unwrapped_text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]\n\n        target_std = unwrapped_text_encoder.get_input_embeddings().weight[index_no_updates].std()\n        new_embeddings = unwrapped_text_encoder.get_input_embeddings().weight[index_updates]\n        target_over_new_std = target_std / new_embeddings.std()\n\n        # Scale the new embeddings towards the target embeddings. 
Raise to the 0.1 power to avoid large changes.\n        new_embeddings = new_embeddings * (target_over_new_std**0.1)\n        unwrapped_text_encoder.get_input_embeddings().weight[index_updates] = new_embeddings\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/tokenize_captions.py",
    "content": "import torch\nfrom transformers import CLIPTokenizer\n\nfrom invoke_training._shared.stable_diffusion.textual_inversion import expand_placeholders_in_caption\n\n\ndef tokenize_captions(tokenizer: CLIPTokenizer, captions: list[str]) -> torch.Tensor:\n    \"\"\"Tokenize a list of captions.\n\n    Args:\n        tokenizer (CLIPTokenizer): The tokenizer.\n        captions (list[str]): The captions.\n\n    Returns:\n        torch.Tensor: The token IDs.\n    \"\"\"\n    caption_token_ids = []\n    for caption in captions:\n        caption = expand_placeholders_in_caption(caption, tokenizer)\n        input = tokenizer(\n            caption,\n            max_length=tokenizer.model_max_length,\n            padding=\"max_length\",\n            truncation=True,\n            return_tensors=\"pt\",\n        )\n        caption_token_ids.append(input.input_ids[0, ...])\n\n    caption_token_ids = torch.stack(caption_token_ids)\n    return caption_token_ids\n"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/validation.py",
    "content": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfrom accelerate.hooks import remove_hook_from_module\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionPipeline,\n    StableDiffusionXLPipeline,\n    UNet2DConditionModel,\n)\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\nfrom invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig\n\n\ndef generate_validation_images_sd(  # noqa: C901\n    epoch: int,\n    step: int,\n    out_dir: str,\n    accelerator: Accelerator,\n    vae: AutoencoderKL,\n    text_encoder: CLIPTextModel,\n    tokenizer: CLIPTokenizer,\n    noise_scheduler: DDPMScheduler,\n    unet: UNet2DConditionModel,\n    config: SdLoraConfig,\n    logger: logging.Logger,\n    callbacks: list[PipelineCallbacks] | None = None,\n):\n    \"\"\"Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout\n    training.\n    \"\"\"\n    # Record original model devices so that we can restore this state after running the pipeline with CPU model\n    # offloading.\n    unet_device = unet.device\n    vae_device = vae.device\n    text_encoder_device = text_encoder.device\n\n    # Create pipeline.\n    pipeline = StableDiffusionPipeline(\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        unet=unet,\n        scheduler=noise_scheduler,\n        safety_checker=None,\n        feature_extractor=None,\n        # TODO(ryand): Add safety checker support.\n        requires_safety_checker=False,\n    )\n    if config.enable_cpu_offload_during_validation:\n        
pipeline.enable_model_cpu_offload(accelerator.device.index or 0)\n    else:\n        pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    validation_resolution = Resolution.parse(config.data_loader.resolution)\n\n    validation_images = ValidationImages(images=[], epoch=epoch, step=step)\n\n    validation_step_dir = os.path.join(out_dir, \"validation\", f\"epoch_{epoch:0>8}-step_{step:0>8}\")\n    logger.info(f\"Generating validation images ({validation_step_dir}).\")\n\n    # Run inference.\n    with torch.no_grad():\n        for prompt_idx in range(len(config.validation_prompts)):\n            positive_prompt = config.validation_prompts[prompt_idx]\n            negative_prompt = None\n            if config.negative_validation_prompts is not None:\n                negative_prompt = config.negative_validation_prompts[prompt_idx]\n            logger.info(f\"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'\")\n\n            generator = torch.Generator(device=accelerator.device)\n            if config.seed is not None:\n                generator = generator.manual_seed(config.seed)\n\n            images = []\n            for _ in range(config.num_validation_images_per_prompt):\n                with accelerator.autocast():\n                    images.append(\n                        pipeline(\n                            positive_prompt,\n                            num_inference_steps=30,\n                            generator=generator,\n                            height=validation_resolution.height,\n                            width=validation_resolution.width,\n                            negative_prompt=negative_prompt,\n                        ).images[0]\n                    )\n\n            # Save images to disk.\n            validation_prompt_dir = os.path.join(validation_step_dir, f\"prompt_{prompt_idx:0>4}\")\n            os.makedirs(validation_prompt_dir)\n            
for image_idx, image in enumerate(images):\n                image_path = os.path.join(validation_prompt_dir, f\"{image_idx:0>4}.jpg\")\n                validation_images.images.append(\n                    ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)\n                )\n                image.save(image_path)\n\n            # Log images to trackers. Currently, only tensorboard is supported.\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\n                        f\"validation (prompt {prompt_idx})\",\n                        np_images,\n                        step,\n                        dataformats=\"NHWC\",\n                    )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    # Remove hooks from models.\n    # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but `StableDiffusionPipeline`\n    # does not offer a way to clean them up so we have to do this manually.\n    for model in [unet, vae, text_encoder]:\n        remove_hook_from_module(model)\n\n    # Restore models to original devices.\n    unet.to(unet_device)\n    vae.to(vae_device)\n    text_encoder.to(text_encoder_device)\n\n    # Run callbacks.\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_validation_images(images=validation_images)\n\n\ndef generate_validation_images_sdxl(  # noqa: C901\n    epoch: int,\n    step: int,\n    out_dir: str,\n    accelerator: Accelerator,\n    vae: AutoencoderKL,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    tokenizer_1: CLIPTokenizer,\n    tokenizer_2: CLIPTokenizer,\n    noise_scheduler: DDPMScheduler,\n    unet: UNet2DConditionModel,\n    config: SdxlLoraConfig,\n    logger: logging.Logger,\n    callbacks: list[PipelineCallbacks] | None = 
None,\n):\n    \"\"\"Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout\n    training.\n    \"\"\"\n    # Record original model devices so that we can restore this state after running the pipeline with CPU model\n    # offloading.\n    unet_device = unet.device\n    vae_device = vae.device\n    text_encoder_1_device = text_encoder_1.device\n    text_encoder_2_device = text_encoder_2.device\n\n    # Create pipeline.\n    pipeline = StableDiffusionXLPipeline(\n        vae=vae,\n        text_encoder=text_encoder_1,\n        text_encoder_2=text_encoder_2,\n        tokenizer=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        unet=unet,\n        scheduler=noise_scheduler,\n    )\n    if config.enable_cpu_offload_during_validation:\n        pipeline.enable_model_cpu_offload(accelerator.device.index or 0)\n    else:\n        pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    validation_resolution = Resolution.parse(config.data_loader.resolution)\n\n    validation_images = ValidationImages(images=[], epoch=epoch, step=step)\n\n    validation_step_dir = os.path.join(out_dir, \"validation\", f\"epoch_{epoch:0>8}-step_{step:0>8}\")\n    logger.info(f\"Generating validation images ({validation_step_dir}).\")\n\n    # Run inference.\n    with torch.no_grad():\n        for prompt_idx in range(len(config.validation_prompts)):\n            positive_prompt = config.validation_prompts[prompt_idx]\n            negative_prompt = None\n            if config.negative_validation_prompts is not None:\n                negative_prompt = config.negative_validation_prompts[prompt_idx]\n            logger.info(f\"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'\")\n\n            generator = torch.Generator(device=accelerator.device)\n            if config.seed is not None:\n                generator = generator.manual_seed(config.seed)\n\n    
        images = []\n            for _ in range(config.num_validation_images_per_prompt):\n                with accelerator.autocast():\n                    images.append(\n                        pipeline(\n                            positive_prompt,\n                            num_inference_steps=30,\n                            generator=generator,\n                            height=validation_resolution.height,\n                            width=validation_resolution.width,\n                            negative_prompt=negative_prompt,\n                        ).images[0]\n                    )\n\n            # Save images to disk.\n            validation_prompt_dir = os.path.join(validation_step_dir, f\"prompt_{prompt_idx:0>4}\")\n            os.makedirs(validation_prompt_dir)\n            for image_idx, image in enumerate(images):\n                image_path = os.path.join(validation_prompt_dir, f\"{image_idx:0>4}.jpg\")\n                validation_images.images.append(\n                    ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)\n                )\n                image.save(image_path)\n\n            # Log images to trackers. 
Currently, only tensorboard is supported.\n            for tracker in accelerator.trackers:\n                if tracker.name == \"tensorboard\":\n                    np_images = np.stack([np.asarray(img) for img in images])\n                    tracker.writer.add_images(\n                        f\"validation (prompt {prompt_idx})\",\n                        np_images,\n                        step,\n                        dataformats=\"NHWC\",\n                    )\n\n    del pipeline\n    torch.cuda.empty_cache()\n\n    # Remove hooks from models.\n    # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but\n    # `StableDiffusionXLPipeline` does not offer a way to clean them up so we have to do this manually.\n    for model in [unet, vae, text_encoder_1, text_encoder_2]:\n        remove_hook_from_module(model)\n\n    # Restore models to original devices.\n    unet.to(unet_device)\n    vae.to(vae_device)\n    text_encoder_1.to(text_encoder_1_device)\n    text_encoder_2.to(text_encoder_2_device)\n\n    # Run callbacks.\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_validation_images(images=validation_images)\n"
  },
  {
    "path": "src/invoke_training/_shared/tools/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/_shared/tools/generate_images.py",
    "content": "import os\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import (\n    PipelineVersionEnum,\n    load_pipeline,\n)\n\n\ndef generate_images(\n    out_dir: str,\n    model: str,\n    hf_variant: str | None,\n    pipeline_version: PipelineVersionEnum,\n    prompts: list[str],\n    set_size: int,\n    num_sets: int,\n    height: int,\n    width: int,\n    loras: Optional[list[tuple[Path, float]]] = None,\n    ti_embeddings: Optional[list[str]] = None,\n    seed: int = 0,\n    torch_dtype: torch.dtype = torch.float16,\n    torch_device: str = \"cuda\",\n    enable_cpu_offload: bool = False,\n):\n    \"\"\"Generate a set of images and store them in a directory. Typically used to generate a dataset for prior\n    preservation / regularization.\n\n    Args:\n        out_dir (str): The output directory to create.\n        model (str): The name or path of the diffusers pipeline to generate with.\n        hf_variant (str | None): The model variant to load (passed to `load_pipeline` as `variant`).\n        pipeline_version (PipelineVersionEnum): The model version.\n        prompts (list[str]): The prompts to generate images with.\n        set_size (int): The number of images in a 'set' for a given prompt.\n        num_sets (int): The number of 'sets' to generate for each prompt.\n        height (int): The output image height in pixels (recommended to match the resolution that the model was trained\n            with).\n        width (int): The output image width in pixels (recommended to match the resolution that the model was trained\n            with).\n        loras (list[tuple[Path, float]], optional): Paths to LoRA models to apply to the base model with associated\n            weights.\n        ti_embeddings (list[str], optional): Paths to TI embeddings to apply to the base model.\n        seed (int, optional): A seed for repeatability. Defaults to 0.\n        torch_dtype (torch.dtype, optional): The torch dtype. Defaults to torch.float16.\n        torch_device (str, optional): The torch device. Defaults to \"cuda\".\n        enable_cpu_offload (bool, optional): If True, models will be loaded onto the GPU one by one to conserve VRAM.\n            Defaults to False.\n    \"\"\"\n\n    pipeline = load_pipeline(model_name_or_path=model, pipeline_version=pipeline_version, variant=hf_variant)\n\n    loras = loras or []\n    for lora in loras:\n        lora_path, lora_scale = lora\n        pipeline.load_lora_weights(str(lora_path), weight_name=str(lora_path.name))\n        pipeline.fuse_lora(lora_scale=lora_scale)\n\n    ti_embeddings = ti_embeddings or []\n    for ti_embedding in ti_embeddings:\n        pipeline.load_textual_inversion(ti_embedding)\n\n    pipeline.to(torch_dtype=torch_dtype)\n    if enable_cpu_offload:\n        pipeline.enable_model_cpu_offload()\n    else:\n        pipeline.to(torch_device=torch_device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    generator = torch.Generator(device=torch_device)\n    if seed is not None:\n        generator = generator.manual_seed(seed)\n\n    os.makedirs(out_dir)\n\n    metadata = []\n\n    total_images = num_sets * len(prompts) * set_size\n    with torch.no_grad(), tqdm(total=total_images) as pbar:\n        for prompt_idx in range(len(prompts)):\n            for set_idx in range(num_sets):\n                set_dir = os.path.join(out_dir, f\"prompt-{prompt_idx:0>4}\", f\"set-{set_idx:0>4}\")\n                os.makedirs(set_dir)\n                set_metadata_dict = {\"prompt\": prompts[prompt_idx]}\n                for image_idx in range(set_size):\n                    image = pipeline(\n                        prompts[prompt_idx],\n                        num_inference_steps=30,\n                        generator=generator,\n                        height=height,\n                        width=width,\n                    ).images[0]\n\n                    image_path = os.path.join(set_dir, f\"image-{image_idx}.jpg\")\n                    image.save(image_path)\n                    set_metadata_dict[f\"image_{image_idx}\"] = os.path.relpath(image_path, start=out_dir)\n                    set_metadata_dict[f\"prefer_{image_idx}\"] = False\n                    pbar.update(1)\n                metadata.append(set_metadata_dict)\n\n    ImagePairPreferenceDataset.save_metadata(metadata=metadata, dataset_dir=out_dir)\n"
  },
  {
    "path": "src/invoke_training/_shared/utils/import_xformers.py",
    "content": "def import_xformers():\n    try:\n        import xformers  # noqa: F401\n    except ImportError:\n        raise ImportError(\n            \"xformers is not installed. Either set `xformers = False` in your training config, or install it using \"\n            '`pip install \".[xformers]\"`.'\n        )\n"
  },
  {
    "path": "src/invoke_training/_shared/utils/jsonl.py",
    "content": "import json\nfrom pathlib import Path\nfrom typing import Any\n\n\ndef load_jsonl(jsonl_path: Path | str) -> list[Any]:\n    \"\"\"Load a JSONL file.\"\"\"\n    data = []\n    with open(jsonl_path) as f:\n        while (line := f.readline().strip()) != \"\":\n            data.append(json.loads(line))\n    return data\n\n\ndef save_jsonl(data: list[Any], jsonl_path: Path | str) -> None:\n    \"\"\"Save a list of objects to a JSONL file.\"\"\"\n    with open(jsonl_path, \"w\") as f:\n        for line in data:\n            f.write(json.dumps(line) + \"\\n\")\n"
  },
  {
    "path": "src/invoke_training/config/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/config/base_pipeline_config.py",
    "content": "import typing\nfrom typing import Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass BasePipelineConfig(ConfigBaseModel):\n    \"\"\"A base config with fields that should be inherited by all pipelines.\"\"\"\n\n    type: str\n\n    seed: Optional[int] = None\n    \"\"\"A randomization seed for reproducible training. Set to any constant integer for consistent training results. If\n    set to `null`, training will be non-deterministic.\n    \"\"\"\n\n    base_output_dir: str\n    \"\"\"The output directory where the training outputs (model checkpoints, logs, intermediate predictions) will be\n    written. A subdirectory will be created with a timestamp for each new training run.\n    \"\"\"\n\n    report_to: typing.Literal[\"all\", \"tensorboard\", \"wandb\", \"comet_ml\"] = \"tensorboard\"\n    \"\"\"The integration to report results and logs to. This value is passed to Hugging Face Accelerate. See\n    `accelerate.Accelerator.log_with` for more details.\n    \"\"\"\n\n    max_train_steps: int | None = None\n    \"\"\"Total number of training steps to perform. One training step is one gradient update.\n\n    One of `max_train_steps` or `max_train_epochs` should be set.\n    \"\"\"\n\n    max_train_epochs: int | None = None\n    \"\"\"Total number of training epochs to perform. 
One epoch is one pass over the entire dataset.\n\n    One of `max_train_steps` or `max_train_epochs` should be set.\n    \"\"\"\n\n    save_every_n_epochs: int | None = None\n    \"\"\"The interval (in epochs) at which to save checkpoints.\n\n    One of `save_every_n_epochs` or `save_every_n_steps` should be set.\n    \"\"\"\n\n    save_every_n_steps: int | None = None\n    \"\"\"The interval (in steps) at which to save checkpoints.\n\n    One of `save_every_n_epochs` or `save_every_n_steps` should be set.\n    \"\"\"\n\n    validate_every_n_epochs: int | None = None\n    \"\"\"The interval (in epochs) at which validation images will be generated.\n\n    One of `validate_every_n_epochs` or `validate_every_n_steps` should be set.\n    \"\"\"\n\n    validate_every_n_steps: int | None = None\n    \"\"\"The interval (in steps) at which validation images will be generated.\n\n    One of `validate_every_n_epochs` or `validate_every_n_steps` should be set.\n    \"\"\"\n"
  },
  {
    "path": "src/invoke_training/config/config_base_model.py",
    "content": "from pydantic import BaseModel, ConfigDict\n\n\nclass ConfigBaseModel(BaseModel):\n    \"\"\"Base model for all invoke training configuration models.\"\"\"\n\n    # Configure to raise if extra fields are passed in.\n    model_config = ConfigDict(extra=\"forbid\")\n"
  },
  {
    "path": "src/invoke_training/config/data/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/config/data/data_loader_config.py",
    "content": "from typing import Literal, Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\nfrom invoke_training.config.data.dataset_config import (\n    ImageCaptionDatasetConfig,\n    ImageDirDatasetConfig,\n)\n\n\nclass AspectRatioBucketConfig(ConfigBaseModel):\n    target_resolution: int\n    \"\"\"The target resolution for all aspect ratios. When generating aspect ratio buckets, the resolution of each bucket\n    is selected to have roughly `target_resolution * target_resolution` pixels (i.e. a square image with dimensions\n    equal to `target_resolution`).\n    \"\"\"\n\n    start_dim: int\n    \"\"\"Aspect ratio bucket resolutions are generated as follows:\n\n    - Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`.\n    - Calculate the 'second' dimension to be as close as possible to the total number of pixels in `target_resolution`,\n    while still being divisible by `divisible_by`.\n\n    tip: Choosing aspect ratio buckets\n        The aspect ratio bucket resolutions are logged at the start of training with the number of images in each\n        bucket. Review these logs to make sure that images are being split into buckets as expected.\n\n        Highly fragmented splits (i.e. 
many buckets with few examples in each) can 1) limit the extent to which examples\n        can be shuffled, and 2) slow down training if there are many partial batches.\n    \"\"\"\n    end_dim: int\n    \"\"\"See explanation under\n    [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim].\n    \"\"\"\n\n    divisible_by: int\n    \"\"\"See explanation under\n    [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim].\n    \"\"\"\n\n\nclass ImageCaptionSDDataLoaderConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_CAPTION_SD_DATA_LOADER\"] = \"IMAGE_CAPTION_SD_DATA_LOADER\"\n\n    dataset: ImageCaptionDatasetConfig\n\n    aspect_ratio_buckets: AspectRatioBucketConfig | None = None\n\n    resolution: int | tuple[int, int] = 512\n    \"\"\"The resolution for input images. Either a scalar integer representing the square resolution height and width, or\n    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the\n    `aspect_ratio_buckets` config is set.\n    \"\"\"\n\n    center_crop: bool = True\n    \"\"\"If True, input images will be center-cropped to the target resolution.\n    If False, input images will be randomly cropped to the target resolution.\n    \"\"\"\n\n    random_flip: bool = False\n    \"\"\"Whether random flip augmentations should be applied to input images.\n    \"\"\"\n\n    caption_prefix: str | None = None\n    \"\"\"A prefix that will be prepended to all captions. If None, no prefix will be added.\n    \"\"\"\n\n    dataloader_num_workers: int = 0\n    \"\"\"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\n    \"\"\"\n\n\nclass ImageCaptionFluxDataLoaderConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_CAPTION_FLUX_DATA_LOADER\"] = \"IMAGE_CAPTION_FLUX_DATA_LOADER\"\n\n    dataset: ImageCaptionDatasetConfig\n\n    aspect_ratio_buckets: AspectRatioBucketConfig | None = None\n\n    resolution: int | tuple[int, int] = 512\n    \"\"\"The resolution for input images. Either a scalar integer representing the square resolution height and width, or\n    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the\n    `aspect_ratio_buckets` config is set.\n    \"\"\"\n\n    center_crop: bool = True\n    \"\"\"If True, input images will be center-cropped to the target resolution.\n    If False, input images will be randomly cropped to the target resolution.\n    \"\"\"\n\n    random_flip: bool = False\n    \"\"\"Whether random flip augmentations should be applied to input images.\n    \"\"\"\n\n    caption_prefix: str | None = None\n    \"\"\"A prefix that will be prepended to all captions. If None, no prefix will be added.\n    \"\"\"\n\n    dataloader_num_workers: int = 0\n    \"\"\"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\n    \"\"\"\n\n\nclass DreamboothSDDataLoaderConfig(ConfigBaseModel):\n    type: Literal[\"DREAMBOOTH_SD_DATA_LOADER\"] = \"DREAMBOOTH_SD_DATA_LOADER\"\n\n    instance_caption: str\n    class_caption: Optional[str] = None\n\n    instance_dataset: ImageDirDatasetConfig\n    class_dataset: Optional[ImageDirDatasetConfig] = None\n\n    class_data_loss_weight: float = 1.0\n    \"\"\"The loss weight applied to class dataset examples. Instance dataset examples have an implicit loss weight of 1.0.\n    \"\"\"\n\n    aspect_ratio_buckets: AspectRatioBucketConfig | None = None\n    \"\"\"The aspect ratio bucketing configuration. 
If None, aspect ratio bucketing is disabled, and all images will be\n    resized to the same resolution.\n    \"\"\"\n\n    resolution: int | tuple[int, int] = 512\n    \"\"\"The resolution for input images. Either a scalar integer representing the square resolution height and width, or\n    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the\n    `aspect_ratio_buckets` config is set.\n    \"\"\"\n\n    center_crop: bool = True\n    \"\"\"If True, input images will be center-cropped to the target resolution.\n    If False, input images will be randomly cropped to the target resolution.\n    \"\"\"\n\n    random_flip: bool = False\n    \"\"\"Whether random flip augmentations should be applied to input images.\n    \"\"\"\n\n    dataloader_num_workers: int = 0\n    \"\"\"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\n    \"\"\"\n\n\nclass TextualInversionSDDataLoaderConfig(ConfigBaseModel):\n    type: Literal[\"TEXTUAL_INVERSION_SD_DATA_LOADER\"] = \"TEXTUAL_INVERSION_SD_DATA_LOADER\"\n\n    dataset: ImageDirDatasetConfig | ImageCaptionDatasetConfig\n\n    caption_preset: Literal[\"style\", \"object\"] | None = None\n\n    caption_templates: list[str] | None = None\n    \"\"\"A list of caption templates with a single template argument 'slot' in each.\n    E.g.:\n\n    - \"a photo of a {}\"\n    - \"a rendering of a {}\"\n    - \"a cropped photo of the {}\"\n    \"\"\"\n\n    keep_original_captions: bool = False\n    \"\"\"If `True`, then the captions generated as a result of the `caption_preset` or `caption_templates` will be used as\n    prefixes for the original captions. If `False`, then the generated captions will replace the original captions.\n    \"\"\"\n\n    aspect_ratio_buckets: AspectRatioBucketConfig | None = None\n    \"\"\"The aspect ratio bucketing configuration. 
If None, aspect ratio bucketing is disabled, and all images will be\n    resized to the same resolution.\n    \"\"\"\n\n    resolution: int | tuple[int, int] = 512\n    \"\"\"The resolution for input images. Either a scalar integer representing the square resolution height and width, or\n    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the\n    `aspect_ratio_buckets` config is set.\n    \"\"\"\n\n    center_crop: bool = True\n    \"\"\"If True, input images will be center-cropped to the target resolution.\n    If False, input images will be randomly cropped to the target resolution.\n    \"\"\"\n\n    random_flip: bool = False\n    \"\"\"Whether random flip augmentations should be applied to input images.\n    \"\"\"\n\n    shuffle_caption_delimiter: str | None = None\n    \"\"\"If `None`, then no caption shuffling is applied. If set, then captions are split on this delimiter and shuffled.\n    \"\"\"\n\n    dataloader_num_workers: int = 0\n    \"\"\"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\n    \"\"\"\n"
  },
  {
    "path": "src/invoke_training/config/data/dataset_config.py",
    "content": "from typing import Annotated, Literal, Optional, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass HFHubImageCaptionDatasetConfig(ConfigBaseModel):\n    type: Literal[\"HF_HUB_IMAGE_CAPTION_DATASET\"] = \"HF_HUB_IMAGE_CAPTION_DATASET\"\n\n    dataset_name: str\n    \"\"\"The name of a Hugging Face dataset.\n    \"\"\"\n\n    dataset_config_name: Optional[str] = None\n    \"\"\"The Hugging Face dataset config name. Leave as None if there's only one config.\n    \"\"\"\n\n    hf_cache_dir: Optional[str] = None\n    \"\"\"The Hugging Face cache directory to use for dataset downloads.\n    If None, the default value will be used (usually '~/.cache/huggingface/datasets').\n    \"\"\"\n\n    image_column: str = \"image\"\n    \"\"\"The name of the dataset column that contains image paths.\n    \"\"\"\n\n    caption_column: str = \"text\"\n    \"\"\"The name of the dataset column that contains captions.\n    \"\"\"\n\n\nclass ImageCaptionJsonlDatasetConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_CAPTION_JSONL_DATASET\"] = \"IMAGE_CAPTION_JSONL_DATASET\"\n\n    jsonl_path: str\n    \"\"\"The path to a JSONL file containing image paths and captions.\"\"\"\n\n    image_column: str = \"image\"\n    \"\"\"The name of the dataset column that contains image paths.\n    \"\"\"\n\n    caption_column: str = \"text\"\n    \"\"\"The name of the dataset column that contains captions.\n    \"\"\"\n\n    keep_in_memory: bool = False\n    \"\"\"If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images\n    are loaded from disk each time they are accessed. 
Setting to `True` improves performance for datasets that are small\n    enough to be kept in memory.\n    \"\"\"\n\n\nclass ImageDirDatasetConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_DIR_DATASET\"] = \"IMAGE_DIR_DATASET\"\n\n    dataset_dir: str\n    \"\"\"The directory to load images from.\"\"\"\n\n    keep_in_memory: bool = False\n    \"\"\"If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images\n    are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small\n    enough to be kept in memory.\n    \"\"\"\n\n\nclass ImageCaptionDirDatasetConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_CAPTION_DIR_DATASET\"] = \"IMAGE_CAPTION_DIR_DATASET\"\n\n    dataset_dir: str\n    \"\"\"The directory to load images from.\"\"\"\n\n    keep_in_memory: bool = False\n    \"\"\"If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images\n    are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small\n    enough to be kept in memory.\n    \"\"\"\n\n\n# Datasets that produce image-caption pairs.\nImageCaptionDatasetConfig = Annotated[\n    Union[HFHubImageCaptionDatasetConfig, ImageCaptionJsonlDatasetConfig, ImageCaptionDirDatasetConfig],\n    Field(discriminator=\"type\"),\n]\n"
  },
  {
    "path": "src/invoke_training/config/optimizer/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/config/optimizer/optimizer_config.py",
    "content": "import typing\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass AdamOptimizerConfig(ConfigBaseModel):\n    optimizer_type: typing.Literal[\"AdamW\"] = \"AdamW\"\n\n    learning_rate: float = 1e-4\n    \"\"\"Initial learning rate to use (after the potential warmup period). Note that in some training pipelines this can\n    be overriden for a specific group of params: https://pytorch.org/docs/stable/optim.html#per-parameter-options\n    (E.g. see `text_encoder_learning_rate` and `unet_learning_rate`)\n    \"\"\"\n\n    beta1: float = 0.9\n    beta2: float = 0.999\n    weight_decay: float = 1e-2\n    epsilon: float = 1e-8\n\n    use_8bit: bool = False\n    \"\"\"Use an 8-bit version of the Adam optimizer. This requires the bitsandbytes library to be installed. use_8bit\n    reduces the VRAM usage of the optimizer, but increases the risk of issues with numerical stability.\n    \"\"\"\n\n\nclass ProdigyOptimizerConfig(ConfigBaseModel):\n    optimizer_type: typing.Literal[\"Prodigy\"] = \"Prodigy\"\n\n    learning_rate: float = 1.0\n    \"\"\"The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A value of 1.0 is\n    recommended. Note that in some training pipelines this can be overriden for a specific group of params:\n    https://pytorch.org/docs/stable/optim.html#per-parameter-options (E.g. see `text_encoder_learning_rate` and\n    `unet_learning_rate`)\n    \"\"\"\n\n    weight_decay: float = 0.0\n    use_bias_correction: bool = False\n    safeguard_warmup: bool = False\n"
  },
  {
    "path": "src/invoke_training/config/pipeline_config.py",
    "content": "from typing import Annotated, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (\n    SdxlLoraAndTextualInversionConfig,\n)\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig\n\nPipelineConfig = Annotated[\n    Union[\n        FluxLoraConfig,\n        SdLoraConfig,\n        SdxlLoraConfig,\n        SdTextualInversionConfig,\n        SdxlTextualInversionConfig,\n        SdxlLoraAndTextualInversionConfig,\n        SdxlFinetuneConfig,\n        SdDirectPreferenceOptimizationLoraConfig,\n    ],\n    Field(discriminator=\"type\"),\n]\n"
  },
  {
    "path": "src/invoke_training/model_merge/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/model_merge/extract_lora.py",
    "content": "import torch\nimport tqdm\nfrom peft.peft_model import PeftModel\n\n# All original base model weights in a PeftModel have this prefix and suffix.\nPEFT_BASE_LAYER_PREFIX = \"base_model.model.\"\nPEFT_BASE_LAYER_SUFFIX = \".base_layer.weight\"\n\n\ndef get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]:\n    \"\"\"Get a state_dict containing the base model weights *thath are patched* in the provided PeftModel. I.e. only\n    return base model weights that have associated LoRa layers, but don't return the LoRA layers.\n    \"\"\"\n    state_dict = peft_model.state_dict()\n    out_state_dict: dict[str, torch.Tensor] = {}\n    for weight_name in state_dict:\n        # Weights that end with \".base_layer.weight\" are the original weights for LoRA layers.\n        if weight_name.endswith(PEFT_BASE_LAYER_SUFFIX):\n            # Extract the base module name.\n            module_name = weight_name[: -len(PEFT_BASE_LAYER_SUFFIX)]\n            assert module_name.startswith(PEFT_BASE_LAYER_PREFIX)\n            module_name = module_name[len(PEFT_BASE_LAYER_PREFIX) :]\n\n            out_state_dict[module_name] = state_dict[weight_name]\n\n    return out_state_dict\n\n\ndef get_state_dict_diff(\n    state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor]\n) -> dict[str, torch.Tensor]:\n    \"\"\"Return the difference between two state_dicts: state_dict_1 - state_dict_2.\"\"\"\n    return {key: state_dict_1[key] - state_dict_2[key] for key in state_dict_1}\n\n\n@torch.no_grad()\ndef extract_lora_from_diffs(\n    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype\n) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:\n    lora_weights = {}\n    for lora_name, mat in tqdm.tqdm(list(diffs.items())):\n        # Use full precision for the intermediate calculations.\n        mat = mat.to(torch.float32)\n\n        is_conv2d = False\n        if len(mat.shape) == 4:  # Conv2D\n   
         is_conv2d = True\n            out_dim, in_dim, kernel_h, kernel_w = mat.shape\n            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).\n            mat = mat.flatten(start_dim=1)\n        elif len(mat.shape) == 2:  # Linear\n            out_dim, in_dim = mat.shape\n        else:\n            raise ValueError(f\"Unexpected weight shape: {mat.shape}\")\n\n        # LoRA rank cannot exceed the original dimensions.\n        assert rank < in_dim\n        assert rank < out_dim\n\n        u: torch.Tensor\n        s: torch.Tensor\n        v_h: torch.Tensor\n        u, s, v_h = torch.linalg.svd(mat)\n\n        # Apply the Eckart-Young-Mirsky theorem.\n        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)\n        u = u[:, :rank]\n        s = s[:rank]\n        u = u @ torch.diag(s)\n\n        v_h = v_h[:rank, :]\n\n        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.\n        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors\n        # to get cleaned up after each operation.\n\n        # Clamp the outliers.\n        dist = torch.cat([u.flatten(), v_h.flatten()])\n        hi_val = torch.quantile(dist, clamp_quantile)\n        low_val = -hi_val\n\n        u = u.clamp(low_val, hi_val)\n        v_h = v_h.clamp(low_val, hi_val)\n\n        if is_conv2d:\n            u = u.reshape(out_dim, rank, 1, 1)\n            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)\n\n        u = u.to(dtype=out_dtype).contiguous()\n        v_h = v_h.to(dtype=out_dtype).contiguous()\n\n        lora_weights[lora_name] = (u, v_h)\n    return lora_weights\n"
  },
  {
    "path": "src/invoke_training/model_merge/merge_models.py",
    "content": "from typing import Literal\n\nimport torch\nimport tqdm\n\nfrom invoke_training.model_merge.utils.normalize_weights import normalize_weights\n\n\n@torch.no_grad()\ndef merge_models(\n    state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal[\"LERP\", \"SLERP\"] = \"LERP\"\n):\n    \"\"\"Merge multiple models into a single model.\n\n    Args:\n        state_dicts (list[dict[str, torch.Tensor]]): The state dicts to merge.\n        weights (list[float]): The weights for each state dict. The weights will be normalized to sum to 1.\n        merge_method (Literal[\"LERP\", \"SLERP\"]): Merge method to use. Options:\n            - \"LERP\": Linear interpolation a.k.a. weighted sum.\n            - \"SLERP\": Spherical linear interpolation.\n    \"\"\"\n    if len(state_dicts) < 2:\n        raise ValueError(\"Must provide >=2 models to merge.\")\n\n    if len(state_dicts) != len(weights):\n        raise ValueError(\"Must provide a weight for each model.\")\n\n    if merge_method == \"LERP\":\n        merge_fn = lerp\n    elif merge_method == \"SLERP\":\n        merge_fn = slerp\n    else:\n        raise ValueError(f\"Unknown merge method: {merge_method}\")\n\n    normalized_weights = normalize_weights(weights)\n\n    out_state_dict: dict[str, torch.Tensor] = state_dicts[0].copy()\n    out_state_dict_weight = normalized_weights[0]\n    for state_dict, normalized_weight in zip(state_dicts[1:], normalized_weights[1:], strict=True):\n        if state_dict.keys() != out_state_dict.keys():\n            raise ValueError(\"State dicts must have the same keys.\")\n\n        cur_pair_weights = normalize_weights([out_state_dict_weight, normalized_weight])\n        for key in tqdm.tqdm(out_state_dict.keys()):\n            out_state_dict[key] = merge_fn(out_state_dict[key], state_dict[key], cur_pair_weights[0])\n\n        # Update the weight of out_state_dict to be the sum of all state dicts merged so far.\n        out_state_dict_weight 
+= normalized_weight\n\n    return out_state_dict\n\n\ndef lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Tensor:\n    \"\"\"Linear interpolation.\"\"\"\n    return torch.lerp(a, b, (1.0 - weight_a))\n\n\ndef slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product_thres=0.9995, epsilon=1e-10):\n    \"\"\"Spherical linear interpolation.\"\"\"\n    # TODO(ryand): For multi-dimensional matrices, it might be better to apply slerp on a subset of the dimensions\n    # (e.g. per-row), rather than treating the entire matrix as a single flattened vector.\n\n    # Normalize the vectors.\n    a_norm = torch.linalg.norm(a)\n    b_norm = torch.linalg.norm(b)\n    a_normalized = a / a_norm\n    b_normalized = b / b_norm\n\n    if a_norm < epsilon or b_norm < epsilon:\n        # If either vector is very small, fallback to lerp to avoid weird effects.\n        # TODO(ryand): Is fallback here necessary?\n        return lerp(a, b, weight_a)\n\n    # Dot product of the normalized vectors.\n    # We are effectively treating multi-dimensional tensors as flattened vectors.\n    dot_prod = torch.sum(a_normalized * b_normalized)\n\n    # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp.\n    if torch.abs(dot_prod) > dot_product_thres:\n        return lerp(a, b, weight_a)\n\n    # Calculate initial angle between the vectors.\n    theta_0 = torch.acos(dot_prod)\n\n    # Angle at timestep t.\n    t = 1.0 - weight_a\n    theta_t = theta_0 * t\n\n    sin_theta_0 = torch.sin(theta_0)\n    sin_theta_t = torch.sin(theta_t)\n\n    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0\n    s1 = sin_theta_t / sin_theta_0\n    result = s0 * a + s1 * b\n\n    return result\n"
  },
  {
    "path": "src/invoke_training/model_merge/merge_tasks_to_base.py",
    "content": "from typing import Literal\n\nimport torch\nimport tqdm\nfrom peft.utils.merge_utils import dare_linear, dare_ties, ties\n\n\n@torch.no_grad()\ndef merge_tasks_to_base_model(\n    base_state_dict: dict[str, torch.Tensor],\n    task_state_dicts: list[dict[str, torch.Tensor]],\n    task_weights: list[float],\n    density: float = 0.2,\n    merge_method: Literal[\"TIES\", \"DARE_LINEAR\", \"DARE_TIES\"] = \"TIES\",\n) -> torch.Tensor:\n    \"\"\"Merge a base model with one or more task-specific models.\n\n    Args:\n        base_state_dict (dict[str, torch.Tensor]): The base state dict to merge with.\n        task_state_dicts (list[dict[str, torch.Tensor]]): A list of task-specific state dicts to merge into the base\n            state dict.\n        task_weights (list[float]): Weights for each task state dict. Weights of 1.0 for all task_state_dicts are\n            recommended as a starting point (e.g. [1.0, 1.0, 1.0]). The weights can be adjusted from there (e.g.\n            [1.0, 1.3, 1.0]). The weights are multipliers applied to the diff between each task_state_dict and the base\n            model.\n        density (float, optional): The fraction of values to preserve in the prune/trim step of DARE/TIES methods.\n            Should be in the range [0, 1].\n        merge_method (Literal[\"TIES\", \"DARE_LINEAR\", \"DARE_TIES\"], optional): The method to use for merging. 
Options:\n            - \"TIES\": Use the TIES method (https://arxiv.org/pdf/2306.01708)\n            - \"DARE_LINEAR\": Use the DARE method with linear interpolation (https://arxiv.org/pdf/2311.03099)\n            - \"DARE_TIES\": Use the DARE method for pruning, and the TIES method for merging.\n    \"\"\"\n    if len(task_state_dicts) != len(task_weights):\n        raise ValueError(\"Must provide a weight for each model.\")\n\n    task_weights = torch.tensor(task_weights)\n\n    # Choose the merging method.\n    if merge_method == \"TIES\":\n        merge_fn = ties\n    elif merge_method == \"DARE_LINEAR\":\n        merge_fn = dare_linear\n    elif merge_method == \"DARE_TIES\":\n        merge_fn = dare_ties\n    else:\n        raise ValueError(f\"Unknown merge method: {merge_method}\")\n\n    out_state_dict: dict[str, torch.Tensor] = {}\n    for key in tqdm.tqdm(base_state_dict.keys()):\n        base_tensor = base_state_dict[key]\n        orig_dtype = base_tensor.dtype\n\n        # Calculate the diff between each task tensor and the base tensor.\n        task_diff_tensors = [state_dict[key] - base_tensor for state_dict in task_state_dicts]\n\n        merged_diff_tensor = merge_fn(\n            task_tensors=task_diff_tensors,\n            weights=task_weights,\n            density=density,\n        )\n\n        # Some of the merge_fn implementations may return a tensor with a different dtype than the original tensors.\n        # We cast the merged_diff_tensor back to the original dtype here.\n        out_state_dict[key] = (base_tensor + merged_diff_tensor).to(dtype=orig_dtype)\n\n    return out_state_dict\n"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py",
    "content": "# This script is based on\n# https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e8c/networks/extract_lora_from_models.py\n# That script was originally based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py\n\nimport argparse\nimport logging\nimport sys\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Literal\n\nimport peft\nimport torch\nfrom diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel\nfrom transformers import CLIPTextModel, CLIPTextModelWithProjection\n\nfrom invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    TEXT_ENCODER_TARGET_MODULES,\n    UNET_TARGET_MODULES,\n    save_sdxl_kohya_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import (\n    PipelineVersionEnum,\n    from_pretrained_with_variant_fallback,\n    load_pipeline,\n)\nfrom invoke_training.model_merge.extract_lora import (\n    PEFT_BASE_LAYER_PREFIX,\n    extract_lora_from_diffs,\n    get_patched_base_weights_from_peft_model,\n    get_state_dict_diff,\n)\nfrom invoke_training.model_merge.utils.parse_model_arg import parse_model_arg\n\n\n@dataclass\nclass StableDiffusionModel:\n    \"\"\"A helper class to store the submodels of a SD model that we are interested in for LoRA extraction.\"\"\"\n\n    unet: UNet2DConditionModel | None = None\n    text_encoder: CLIPTextModel | None = None\n    text_encoder_2: CLIPTextModelWithProjection | None = None\n\n    def all_none(self) -> bool:\n        return self.unet is None and self.text_encoder is None and self.text_encoder_2 is None\n\n\ndef load_model(\n    logger: logging.Logger,\n    model_name_or_path: str,\n    model_type: PipelineVersionEnum,\n    variant: str | None,\n    dtype: torch.dtype,\n) -> StableDiffusionModel:\n    sd_model = 
StableDiffusionModel()\n\n    model_path = Path(model_name_or_path)\n    if model_path.is_dir():\n        # model_path is a directory, so we'll try to load the submodels of interest from its subdirectories.\n        logger.info(f\"'{model_name_or_path}' is a directory. Attempting to load submodels.\")\n        for submodel_name, submodel_class in [\n            (\"unet\", UNet2DConditionModel),\n            (\"text_encoder\", CLIPTextModel),\n            (\"text_encoder_2\", CLIPTextModelWithProjection),\n        ]:\n            submodel_path: Path = model_path / submodel_name\n            if submodel_path.exists():\n                logger.info(f\"Loading '{submodel_name}' from '{submodel_path}'.\")\n                submodel = from_pretrained_with_variant_fallback(\n                    logger=logger,\n                    model_class=submodel_class,\n                    model_name_or_path=submodel_path,\n                    torch_dtype=dtype,\n                    variant=variant,\n                    local_files_only=True,\n                )\n                setattr(sd_model, submodel_name, submodel)\n            else:\n                logger.info(f\"'{submodel_name}' not found in '{model_name_or_path}'. Skipping.\")\n                continue\n    else:\n        # model_name_or_path is not a directory, so it is either:\n        # 1) a single checkpoint file\n        # 2) a HF model name\n        # Both can be loaded by calling load_pipeline.\n        logger.info(f\"'{model_name_or_path}' is a single checkpoint file. 
Attempting to load.\")\n        pipeline = load_pipeline(\n            logger=logger,\n            model_name_or_path=model_name_or_path,\n            pipeline_version=model_type,\n            torch_dtype=dtype,\n            variant=variant,\n        )\n        if isinstance(pipeline, StableDiffusionPipeline):\n            sd_model.unet = pipeline.unet\n            sd_model.text_encoder = pipeline.text_encoder\n        elif isinstance(pipeline, StableDiffusionXLPipeline):\n            sd_model.unet = pipeline.unet\n            sd_model.text_encoder = pipeline.text_encoder\n            sd_model.text_encoder_2 = pipeline.text_encoder_2\n        else:\n            raise RuntimeError(f\"Unexpected pipeline type: {type(pipeline)}.\")\n\n    if sd_model.all_none():\n        raise RuntimeError(f\"Failed to load any submodels from '{model_name_or_path}'.\")\n\n    return sd_model\n\n\ndef str_to_device(device_str: Literal[\"cuda\", \"cpu\"]) -> torch.device:\n    if device_str == \"cuda\":\n        return torch.device(\"cuda\")\n    elif device_str == \"cpu\":\n        return torch.device(\"cpu\")\n    else:\n        raise ValueError(f\"Unexpected device: {device_str}\")\n\n\ndef state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:\n    return {k: v.to(device=device) for k, v in state_dict.items()}\n\n\ndef extract_lora_from_submodel(\n    logger: logging.Logger,\n    model_orig: torch.nn.Module,\n    model_tuned: torch.nn.Module,\n    device: torch.device,\n    out_dtype: torch.dtype,\n    lora_target_modules: list[str],\n    lora_rank: int,\n    clamp_quantile: float = 0.99,\n) -> peft.PeftModel:\n    \"\"\"Extract LoRA weights from the diff between model_orig and model_tuned. Returns a new model_orig, wrapped in a\n    PeftModel, with the LoRA weights applied.\n    \"\"\"\n    # Apply LoRA to the UNet.\n    # The only reason we do this is to get the module names for the weights that we'll extract. 
We don't actually use\n    # the LoRA weights initialized here.\n    unet_lora_config = peft.LoraConfig(\n        r=lora_rank,\n        # We set the alpha to the rank, because we don't want any scaling to be applied to the LoRA weights that we\n        # extract.\n        lora_alpha=lora_rank,\n        target_modules=lora_target_modules,\n    )\n    model_tuned = peft.get_peft_model(model_tuned, unet_lora_config)\n    model_orig = peft.get_peft_model(model_orig, unet_lora_config)\n\n    base_weights_tuned = get_patched_base_weights_from_peft_model(model_tuned)\n    base_weights_orig = get_patched_base_weights_from_peft_model(model_orig)\n\n    diffs = get_state_dict_diff(base_weights_tuned, base_weights_orig)\n\n    # Clear tuned model to save memory.\n    # TODO(ryand): We also need to clear the state_dicts. Move the diff extraction to a separate function so that memory\n    # cleanup is handled by scoping.\n    del model_tuned\n\n    # Apply SVD (Singluar Value Decomposition) to the diffs.\n    # We just use the device for this calculation, since it's slow, then we move the results back to the CPU.\n    logger.info(\"Calculating LoRA weights with SVD.\")\n    diffs = state_dict_to_device(diffs, device)\n    # TODO(ryand): Should we skip if the diffs are all zeros? This would happen if two models are identical. 
This could\n    # happen if some submodels differ while others don't.\n    lora_weights = extract_lora_from_diffs(\n        diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype\n    )\n\n    # Prepare state dict for LoRA.\n    lora_state_dict = {}\n    for module_name, (lora_up, lora_down) in lora_weights.items():\n        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + \".lora_A.default.weight\"] = lora_down\n        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + \".lora_B.default.weight\"] = lora_up\n        # The alpha value is set once globally in the PEFT model, so no need to set it for each module.\n        # lora_state_dict[peft_base_layer_suffix + module_name + \".alpha\"] = torch.tensor(down_weight.size()[0])\n\n    lora_state_dict = state_dict_to_device(lora_state_dict, torch.device(\"cpu\"))\n\n    # Load the state_dict into the LoRA model.\n    model_orig.load_state_dict(lora_state_dict, strict=False, assign=True)\n\n    return model_orig\n\n\n@torch.no_grad()\ndef extract_lora(\n    logger: logging.Logger,\n    model_type: PipelineVersionEnum,\n    orig_model_name_or_path: str,\n    orig_model_variant: str | None,\n    tuned_model_name_or_path: str,\n    tuned_model_variant: str | None,\n    save_to: str,\n    load_precision: Literal[\"float32\", \"float16\", \"bfloat16\"],\n    save_precision: Literal[\"float32\", \"float16\", \"bfloat16\"],\n    device: Literal[\"cuda\", \"cpu\"],\n    lora_rank: int,\n    clamp_quantile=0.99,\n):\n    load_dtype = get_dtype_from_str(load_precision)\n    save_dtype = get_dtype_from_str(save_precision)\n    device = str_to_device(device)\n\n    orig_model = load_model(\n        logger=logger,\n        model_name_or_path=orig_model_name_or_path,\n        model_type=model_type,\n        dtype=load_dtype,\n        variant=orig_model_variant,\n    )\n    tuned_model = load_model(\n        logger=logger,\n        model_name_or_path=tuned_model_name_or_path,\n        
model_type=model_type,\n        dtype=load_dtype,\n        variant=tuned_model_variant,\n    )\n\n    lora_models: dict[str, peft.PeftModel] = {}\n    for submodel_name, submodel_orig, submodel_tuned, lora_target_modules in [\n        (\"unet\", orig_model.unet, tuned_model.unet, UNET_TARGET_MODULES),\n        (\"text_encoder\", orig_model.text_encoder, tuned_model.text_encoder, TEXT_ENCODER_TARGET_MODULES),\n        (\"text_encoder_2\", orig_model.text_encoder_2, tuned_model.text_encoder_2, TEXT_ENCODER_TARGET_MODULES),\n    ]:\n        if submodel_orig is not None and submodel_tuned is not None:\n            logger.info(f\"Extracting LoRA weights for '{submodel_name}'.\")\n            lora_models[submodel_name] = extract_lora_from_submodel(\n                logger=logger,\n                model_orig=submodel_orig,\n                model_tuned=submodel_tuned,\n                device=device,\n                out_dtype=save_dtype,\n                lora_target_modules=lora_target_modules,\n                lora_rank=lora_rank,\n                clamp_quantile=clamp_quantile,\n            )\n        else:\n            logger.info(f\"Skipping '{submodel_name}'.\")\n\n    # Save the LoRA weights.\n    save_to_path = Path(save_to)\n    assert save_to_path.suffix == \".safetensors\"\n    if save_to_path.exists():\n        raise FileExistsError(f\"Destination file already exists: '{save_to}'.\")\n    save_to_path.parent.mkdir(parents=True, exist_ok=True)\n    save_sdxl_kohya_checkpoint(\n        save_to_path,\n        unet=lora_models.get(\"unet\", None),\n        text_encoder_1=lora_models.get(\"text_encoder\", None),\n        text_encoder_2=lora_models.get(\"text_encoder_2\", None),\n    )\n\n    logger.info(f\"Saved LoRA weights to: {save_to_path}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        choices=[\"SD\", \"SDXL\"],\n        help=\"The type of the models to merge ['SD', 
'SDXL'].\",\n    )\n    parser.add_argument(\n        \"--model-orig\",\n        type=str,\n        required=True,\n        help=\"Path or HF Hub name of the original model. The model must be in one of the following formats: \"\n        \"1) a single checkpoint file (e.g. '.safetensors') containing all submodels, \"\n        \"2) a model in diffusers format containing all submodels, \"\n        \"or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet).\"\n        \"An HF variant can optionally be appended to the model name after a double-colon delimiter ('::').\"\n        \"E.g. '--model-orig runwayml/stable-diffusion-v1-5::fp16'\",\n    )\n    parser.add_argument(\n        \"--model-tuned\",\n        type=str,\n        required=True,\n        help=\"Path or HF Hub name of the tuned model. The model must be in one of the following formats: \"\n        \"1) a single checkpoint file (e.g. '.safetensors') containing all submodels, \"\n        \"2) a model in diffusers format containing all submodels, \"\n        \"or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet).\"\n        \"An HF variant can optionally be appended to the model name after a double-colon delimiter ('::').\"\n        \"E.g. 
'--model-orig runwayml/stable-diffusion-v1-5::fp16'\",\n    )\n    parser.add_argument(\n        \"--save-to\",\n        type=str,\n        required=True,\n        help=\"Destination file path (must have a .safetensors extension).\",\n    )\n    parser.add_argument(\n        \"--load-precision\",\n        type=str,\n        default=\"bfloat16\",\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n        help=\"Model load precision.\",\n    )\n    parser.add_argument(\n        \"--save-precision\",\n        type=str,\n        default=\"float16\",\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n        help=\"Model save precision.\",\n    )\n\n    parser.add_argument(\"--lora-rank\", type=int, default=4, help=\"LoRA rank dimension.\")\n    parser.add_argument(\"--clamp-quantile\", type=float, default=0.99, help=\"Quantile clamping value. (0-1)\")\n    parser.add_argument(\n        \"--device\", type=str, default=\"cuda\", choices=[\"cuda\", \"cpu\"], help=\"Device to use. (cuda or cpu)\"\n    )\n\n    args = parser.parse_args()\n\n    logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n    logger = logging.getLogger()\n\n    orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig)\n    tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned)\n\n    extract_lora(\n        logger=logger,\n        model_type=PipelineVersionEnum(args.model_type),\n        orig_model_name_or_path=orig_model_name_or_path,\n        orig_model_variant=orig_model_variant,\n        tuned_model_name_or_path=tuned_model_name_or_path,\n        tuned_model_variant=tuned_model_variant,\n        save_to=args.save_to,\n        load_precision=args.load_precision,\n        save_precision=args.save_precision,\n        device=args.device,\n        lora_rank=args.lora_rank,\n        clamp_quantile=args.clamp_quantile,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_lora_into_model.py",
    "content": "import argparse  # noqa: I001\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline\n\n# fmt: off\n# HACK(ryand): Import order matters, because invokeai contains circular imports.\nfrom invokeai.backend.model_manager.taxonomy import BaseModelType\nfrom invokeai.backend.patches.layer_patcher import LayerPatcher\nfrom invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import \\\n    lora_model_from_sd_state_dict\nfrom invokeai.backend.util.original_weights_storage import \\\n    OriginalWeightsStorage\nfrom safetensors.torch import load_file\n\n# fmt: on\nfrom invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline\nfrom invoke_training.model_merge.utils.parse_model_arg import parse_model_arg\n\n\ndef to_invokeai_base_model_type(model_type: PipelineVersionEnum):\n    if model_type == PipelineVersionEnum.SD:\n        return BaseModelType.StableDiffusion1\n    elif model_type == PipelineVersionEnum.SDXL:\n        return BaseModelType.StableDiffusionXL\n    else:\n        raise ValueError(f\"Unexpected model_type: {model_type}\")\n\n\n@torch.no_grad()\ndef merge_lora_into_sd_model(\n    logger: logging.Logger,\n    model_type: PipelineVersionEnum,\n    base_model: str,\n    base_model_variant: str | None,\n    lora_models: list[tuple[str, float]],\n    output: str,\n    save_dtype: str,\n):\n    pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline = load_pipeline(\n        logger=logger, model_name_or_path=base_model, pipeline_version=model_type, variant=base_model_variant\n    )\n    save_dtype = get_dtype_from_str(save_dtype)\n\n    logger.info(f\"Loaded base model: '{base_model}'.\")\n\n    pipeline.to(save_dtype)\n\n    models: list[torch.nn.Module] = []\n    lora_prefixes: list[str] = []\n    if isinstance(pipeline, 
StableDiffusionPipeline):\n        models = [pipeline.unet, pipeline.text_encoder]\n        lora_prefixes = [\"lora_unet_\", \"lora_te_\"]\n    elif isinstance(pipeline, StableDiffusionXLPipeline):\n        models = [pipeline.unet, pipeline.text_encoder, pipeline.text_encoder_2]\n        lora_prefixes = [\"lora_unet_\", \"lora_te1_\", \"lora_te2_\"]\n    else:\n        raise ValueError(f\"Unexpected pipeline type: {type(pipeline)}\")\n\n    # Although we are not unpatching, the patcher might require this. Initialize empty.\n    original_weights = OriginalWeightsStorage()\n\n    for lora_model_path, lora_model_weight in lora_models:\n        # Load state dict from file\n        lora_path = Path(lora_model_path)\n        if lora_path.suffix == \".safetensors\":\n            state_dict = load_file(lora_path.absolute().as_posix(), device=\"cpu\")\n        else:\n            # Assuming .ckpt, .pt, .bin etc. are torch checkpoints\n            state_dict = torch.load(lora_path, map_location=\"cpu\")\n\n        # Convert state dict to ModelPatchRaw\n        lora_model = lora_model_from_sd_state_dict(state_dict=state_dict)\n\n        # Apply the patch using LayerPatcher\n        for model, lora_prefix in zip(models, lora_prefixes, strict=True):\n            LayerPatcher.apply_smart_model_patch(\n                model=model,\n                prefix=lora_prefix,\n                patch=lora_model,\n                patch_weight=lora_model_weight,\n                original_weights=original_weights,  # Pass storage, even if unused for merging\n                original_modules={},  # Pass empty dict, not needed for direct patching/merging\n                dtype=model.dtype,  # Use the model's dtype\n                # Force direct patching since we are merging into the main weights\n                force_direct_patching=True,\n                force_sidecar_patching=False,\n            )\n        logger.info(f\"Applied LoRA model '{lora_model_path}' with weight 
{lora_model_weight}.\")\n\n    output_path = Path(output)\n    output_path.mkdir(parents=True)\n\n    # TODO(ryand): Should we keep the base model variant? This is clearly a flawed assumption.\n    pipeline.save_pretrained(output_path, variant=base_model_variant)\n    logger.info(f\"Saved merged model to '{output_path}'.\")\n\n\ndef parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:\n    \"\"\"Parse a --lora-model argument into a tuple of the model path and weight.\"\"\"\n    parts = lora_model_arg.split(\"::\")\n    if len(parts) == 1:\n        return parts[0], 1.0\n    elif len(parts) == 2:\n        return parts[0], float(parts[1])\n    else:\n        raise ValueError(f\"Unexpected format for --lora-model arg: '{lora_model_arg}'.\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        choices=[\"SD\", \"SDXL\"],\n        help=\"The type of the models to merge ['SD', 'SDXL'].\",\n    )\n    parser.add_argument(\n        \"--base-model\",\n        type=str,\n        help=\"The base model to merge LoRAs into. The model can be either 1) an HF hub name, 2) a path to a local \"\n        \"diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended \"\n        \"to the model name after a double-colon delimiter ('::').\"\n        \"E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--lora-models\",\n        type=str,\n        nargs=\"+\",\n        help=\"The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to \"\n        \"the path, separated by a double colon ('::'). The weight is optional and defaults to 1.0. E.g. 
\"\n        \"'--lora-models path/to/lora_model_1.safetensors::0.5 path/to/lora_model_2.safetensors'.\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        help=\"The path to an output directory where the merged model will be saved (in diffusers format).\",\n    )\n    parser.add_argument(\n        \"--save-dtype\",\n        type=str,\n        default=\"float16\",\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n        help=\"The dtype to save the model as.\",\n    )\n\n    args = parser.parse_args()\n\n    logging.basicConfig(level=logging.INFO)\n    logger = logging.getLogger()\n\n    base_model, base_model_variant = parse_model_arg(args.base_model)\n    lora_models = [parse_lora_model_arg(arg) for arg in args.lora_models]\n\n    # Log the parsed arguments\n    logger.info(f\"Model type: {args.model_type}\")\n    logger.info(f\"Base model: {base_model}\")\n    logger.info(f\"Base model variant: {base_model_variant}\")\n    logger.info(f\"Output directory: {args.output}\")\n    logger.info(f\"Save dtype: {args.save_dtype}\")\n    lora_models_str = \"  - \" + \"\\n  - \".join([f\"{model} ({weight})\" for model, weight in lora_models])\n    logger.info(f\"LoRA models:\\n{lora_models_str}\")\n\n    merge_lora_into_sd_model(\n        logger=logger,\n        model_type=PipelineVersionEnum(args.model_type),\n        base_model=base_model,\n        base_model_variant=base_model_variant,\n        lora_models=lora_models,\n        output=args.output,\n        save_dtype=args.save_dtype,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_models.py",
    "content": "import argparse\nimport logging\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline\n\nfrom invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline\nfrom invoke_training.model_merge.merge_models import merge_models\nfrom invoke_training.model_merge.utils.parse_model_arg import parse_model_arg\n\n\n@dataclass\nclass MergeModel:\n    model_name_or_path: str\n    variant: str | None\n    weight: float\n\n\ndef run_merge_models(\n    logger: logging.Logger,\n    model_type: PipelineVersionEnum,\n    models: list[MergeModel],\n    method: str,\n    out_dir: str,\n    dtype: torch.dtype,\n):\n    # Create the output directory if it doesn't exist.\n    out_dir_path = Path(out_dir)\n    out_dir_path.mkdir(parents=True, exist_ok=False)\n\n    # Load the models.\n    loaded_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = []\n    for model in models:\n        loaded_model = load_pipeline(\n            logger=logger,\n            model_name_or_path=model.model_name_or_path,\n            pipeline_version=model_type,\n            torch_dtype=dtype,\n            variant=model.variant,\n        )\n        loaded_models.append(loaded_model)\n\n    # Select the submodels to merge.\n    if model_type == PipelineVersionEnum.SDXL:\n        submodel_names = [\"unet\", \"text_encoder\", \"text_encoder_2\"]\n    elif model_type == PipelineVersionEnum.SD:\n        submodel_names = [\"unet\", \"text_encoder\"]\n    else:\n        raise ValueError(f\"Unexpected model type: {model_type}\")\n\n    # Merge the models.\n    weights = [model.weight for model in models]\n    for submodel_name in submodel_names:\n        submodels: list[torch.nn.Module] = [getattr(loaded_model, submodel_name) for loaded_model in loaded_models]\n  
      submodel_state_dicts: list[dict[str, torch.Tensor]] = [submodel.state_dict() for submodel in submodels]\n\n        logger.info(f\"Merging {submodel_name} state_dicts...\")\n        merged_state_dict = merge_models(state_dicts=submodel_state_dicts, weights=weights, merge_method=method)\n\n        # Merge the merged_state_dict back into the first pipeline to keep memory utilization low.\n        submodels[0].load_state_dict(merged_state_dict, assign=True)\n        logger.info(f\"Merged {submodel_name} state_dicts.\")\n\n    # Save the merged model.\n    logger.info(\"Saving result...\")\n    loaded_models[0].save_pretrained(out_dir_path)\n    logger.info(f\"Saved merged model to '{out_dir_path}'.\")\n\n\ndef parse_model_args(models: list[str], weights: list[str]) -> list[MergeModel]:\n    \"\"\"Parse a list of --models arguments and --weights arguments into a list of MergeModels.\"\"\"\n    merge_model_list: list[MergeModel] = []\n    for model, weight in zip(models, weights, strict=True):\n        parsed_model, parsed_variant = parse_model_arg(model)\n        merge_model_list.append(\n            MergeModel(model_name_or_path=parsed_model, variant=parsed_variant, weight=float(weight))\n        )\n\n    return merge_model_list\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n\n    # TODO(ryand): Auto-detect the model-type.\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        choices=[\"SD\", \"SDXL\"],\n        help=\"The type of the models to merge ['SD', 'SDXL'].\",\n    )\n    parser.add_argument(\n        \"--models\",\n        nargs=\"+\",\n        type=str,\n        required=True,\n        help=\"Two or more models to merge. Each model can be either 1) an HF hub name, 2) a path to a local diffusers \"\n        \"model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended to the \"\n        \"model name after a double-colon delimiter ('::').\"\n        \"E.g. 
'--models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'\",\n    )\n    parser.add_argument(\n        \"--weights\",\n        nargs=\"+\",\n        type=float,\n        required=True,\n        help=\"The weights for each model. The weights will be normalized to sum to 1. \"\n        \"For example, to merge weights with equal weights: '--weights 1.0 1.0'. \"\n        \"To weight the first model more heavily: '--weights 0.75 0.25'.\",\n    )\n    parser.add_argument(\n        \"--method\",\n        type=str,\n        default=\"LERP\",\n        choices=[\"LERP\", \"SLERP\"],\n        help=\"The merge method to use. Options: 'LERP' (linear interpolation) or 'SLERP' (spherical linear \"\n        \"interpolation).\",\n    )\n    parser.add_argument(\n        \"--out-dir\",\n        type=str,\n        required=True,\n        help=\"The output directory where the merged model will be written (in diffusers format).\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        help=\"The torch dtype that will be used for all calculations and for the output model.\",\n        type=str,\n        default=\"float16\",\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n    )\n\n    args = parser.parse_args()\n\n    logging.basicConfig(level=logging.INFO)\n    logger = logging.getLogger(__name__)\n\n    merge_model_list = parse_model_args(args.models, args.weights)\n    run_merge_models(\n        logger=logger,\n        model_type=PipelineVersionEnum(args.model_type),\n        models=merge_model_list,\n        method=args.method,\n        out_dir=args.out_dir,\n        dtype=get_dtype_from_str(args.dtype),\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py",
    "content": "import argparse\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline\n\nfrom invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline\nfrom invoke_training.model_merge.merge_tasks_to_base import merge_tasks_to_base_model\nfrom invoke_training.model_merge.scripts.merge_models import MergeModel, parse_model_args\n\n\ndef run_merge_models(\n    logger: logging.Logger,\n    model_type: PipelineVersionEnum,\n    base_model: MergeModel,\n    task_models: list[MergeModel],\n    method: str,\n    density: float,\n    out_dir: str,\n    dtype: torch.dtype,\n):\n    # Create the output directory if it doesn't exist.\n    out_dir_path = Path(out_dir)\n    out_dir_path.mkdir(parents=True, exist_ok=False)\n\n    # Load the base model.\n    loaded_base_model = load_pipeline(\n        logger=logger,\n        model_name_or_path=base_model.model_name_or_path,\n        pipeline_version=model_type,\n        torch_dtype=dtype,\n        variant=base_model.variant,\n    )\n\n    # Load the task models.\n    loaded_task_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = []\n    for task_model in task_models:\n        loaded_task_model = load_pipeline(\n            logger=logger,\n            model_name_or_path=task_model.model_name_or_path,\n            pipeline_version=model_type,\n            torch_dtype=dtype,\n            variant=task_model.variant,\n        )\n        loaded_task_models.append(loaded_task_model)\n\n    # Select the submodels to merge.\n    if model_type == PipelineVersionEnum.SDXL:\n        submodel_names = [\"unet\", \"text_encoder\", \"text_encoder_2\"]\n    elif model_type == PipelineVersionEnum.SD:\n        submodel_names = [\"unet\", \"text_encoder\"]\n    else:\n        raise ValueError(f\"Unexpected model type: 
{model_type}\")\n\n    # Merge the models.\n    task_model_weights = [task_model.weight for task_model in task_models]\n    for submodel_name in submodel_names:\n        base_submodel: torch.nn.Module = getattr(loaded_base_model, submodel_name)\n        base_submodel_state_dict = base_submodel.state_dict()\n        task_submodels: list[torch.nn.Module] = [\n            getattr(loaded_task_model, submodel_name) for loaded_task_model in loaded_task_models\n        ]\n        task_submodel_state_dict = [submodel.state_dict() for submodel in task_submodels]\n\n        logger.info(f\"Merging {submodel_name} state_dicts...\")\n        merged_state_dict = merge_tasks_to_base_model(\n            base_state_dict=base_submodel_state_dict,\n            task_state_dicts=task_submodel_state_dict,\n            task_weights=task_model_weights,\n            density=density,\n            merge_method=method,\n        )\n\n        # Merge the merged_state_dict back into the base model pipeline to keep memory utilization low.\n        base_submodel.load_state_dict(merged_state_dict, assign=True)\n        logger.info(f\"Merged {submodel_name} state_dicts.\")\n\n    # Delete the task models to free up memory.\n    # At the time of the writing, the save_pretrained(...) function below caused a large spike in memory usage. 
We free\n    # the task models to increase its likelihood of success.\n    del loaded_task_models\n\n    # Save the merged model.\n    logger.info(\"Saving result...\")\n    loaded_base_model.save_pretrained(out_dir_path)\n    logger.info(f\"Saved merged model to '{out_dir_path}'.\")\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    # TODO(ryand): Auto-detect the base-model-type.\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        choices=[\"SD\", \"SDXL\"],\n        help=\"The type of the models to merge ['SD', 'SDXL'].\",\n    )\n    parser.add_argument(\n        \"--base-model\",\n        type=str,\n        help=\"The base model to merge task-specific models into. Can be either 1) an HF hub name, 2) a path to a local \"\n        \"diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended \"\n        \"to the model name after a double-colon delimiter ('::').\"\n        \"E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'.\",\n    )\n    parser.add_argument(\n        \"--task-models\",\n        nargs=\"+\",\n        type=str,\n        required=True,\n        help=\"One or more task-specific models to merge into the base model. Each model can be either 1) an HF hub \"\n        \"name, 2) a path to a local diffusers model directory, or 3) a path to a single checkpoint file. An HF variant \"\n        \"can optionally be appended to the model name after a double-colon delimiter ('::').\"\n        \"E.g. '--task-models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'\",\n    )\n    parser.add_argument(\n        \"--task-weights\",\n        nargs=\"+\",\n        type=float,\n        required=True,\n        help=\"The weights for each task model. The weights are multipliers applied to the diff between each task model \"\n        \"and the base model. As a starting point, it is recommended to use a weight of 1.0 for all task models, e.g. 
\"\n        \"'--task-weights 1.0 1.0'. The weights can then be tuned from there, e.g. '--task-weights 1.0 1.3'.\",\n    )\n    parser.add_argument(\n        \"--method\",\n        type=str,\n        default=\"TIES\",\n        choices=[\"TIES\", \"DARE_LINEAR\", \"DARE_TIES\"],\n        help=\"The merge method to use. Options: ['TIES', 'DARE_LINEAR', 'DARE_TIES'].\",\n    )\n    parser.add_argument(\n        \"--density\",\n        type=float,\n        default=0.2,\n        help=\"The fraction of values to preserve in the prune/trim step of DARE/TIES methods. Should be in the range \"\n        \"[0, 1].\",\n    )\n    parser.add_argument(\n        \"--out-dir\",\n        type=str,\n        required=True,\n        help=\"The output directory where the merged model will be written (in diffusers format).\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        help=\"The torch dtype that will be used for all calculations and for the output model.\",\n        type=str,\n        default=\"float16\",\n        choices=[\"float32\", \"float16\", \"bfloat16\"],\n    )\n\n    args = parser.parse_args()\n\n    logging.basicConfig(level=logging.INFO)\n    logger = logging.getLogger()\n\n    base_model = parse_model_args([args.base_model], [1.0])[0]\n    task_models = parse_model_args(args.task_models, args.task_weights)\n    run_merge_models(\n        logger=logger,\n        model_type=PipelineVersionEnum(args.model_type),\n        base_model=base_model,\n        task_models=task_models,\n        method=args.method,\n        density=args.density,\n        out_dir=args.out_dir,\n        dtype=get_dtype_from_str(args.dtype),\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/model_merge/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/model_merge/utils/normalize_weights.py",
    "content": "def normalize_weights(weights: list[float]) -> list[float]:\n    total = sum(weights)\n    return [weight / total for weight in weights]\n"
  },
  {
    "path": "src/invoke_training/model_merge/utils/parse_model_arg.py",
    "content": "def parse_model_arg(model: str, delimiter: str = \"::\") -> tuple[str, str | None]:\n    \"\"\"Parse a model argument into a model and a variant.\"\"\"\n    parts = model.split(delimiter)\n    if len(parts) == 1:\n        return parts[0], None\n    elif len(parts) == 2:\n        return parts[0], parts[1]\n    else:\n        raise ValueError(f\"Unexpected format for --models arg: '{model}'.\")\n"
  },
  {
    "path": "src/invoke_training/pipelines/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.config_base_model import ConfigBaseModel\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass HFHubImagePairPreferenceDatasetConfig(ConfigBaseModel):\n    type: Literal[\"HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET\"] = \"HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET\"\n\n    # TODO(ryand): Fill this out.\n\n\nclass ImagePairPreferenceDatasetConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_PAIR_PREFERENCE_DATASET\"] = \"IMAGE_PAIR_PREFERENCE_DATASET\"\n\n    dataset_dir: str\n    \"\"\"The directory to load the dataset from.\"\"\"\n\n\nclass ImagePairPreferenceSDDataLoaderConfig(ConfigBaseModel):\n    type: Literal[\"IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER\"] = \"IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER\"\n\n    dataset: Annotated[\n        Union[HFHubImagePairPreferenceDatasetConfig, ImagePairPreferenceDatasetConfig], Field(discriminator=\"type\")\n    ]\n\n    resolution: int | tuple[int, int] = 512\n    \"\"\"The resolution for input images. Either a scalar integer representing the square resolution height and width, or\n    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the\n    `aspect_ratio_buckets` config is set.\n    \"\"\"\n\n    center_crop: bool = True\n    \"\"\"If True, input images will be center-cropped to the target resolution.\n    If False, input images will be randomly cropped to the target resolution.\n    \"\"\"\n\n    random_flip: bool = False\n    \"\"\"Whether random flip augmentations should be applied to input images.\n    \"\"\"\n\n    dataloader_num_workers: int = 0\n    \"\"\"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\n    \"\"\"\n\n\nclass SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):\n    type: Literal[\"SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA\"] = \"SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA\"\n\n    model: str = \"runwayml/stable-diffusion-v1-5\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    # Note: Pydantic handles mutable default values well:\n    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values\n    base_embeddings: dict[str, str] = {}\n    \"\"\"A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model\n    before training.\n\n    Example:\n    ```\n    base_embeddings = {\n        \"bruce_the_gnome\": \"/path/to/bruce_the_gnome.safetensors\",\n    }\n    ```\n\n    Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the\n    dataset captions.\n\n    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they\n    are referenced in the dataset captions. The list of embeddings provided here should be the same list used at\n    generation time with the resultant LoRA model.\n    \"\"\"\n\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"kohya\"\n    \"\"\"The format of the LoRA checkpoint to save. 
Choose between `invoke_peft` or `kohya`.\"\"\"\n\n    train_unet: bool = True\n    \"\"\"Whether to add LoRA layers to the UNet model and train it.\n    \"\"\"\n\n    train_text_encoder: bool = True\n    \"\"\"Whether to add LoRA layers to the text encoder and train it.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    text_encoder_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning\n    rate. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    unet_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.\n    Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    lora_rank_dim: int = 4\n    \"\"\"The rank dimension to use for the LoRA layers. 
Increasing the rank dimension increases the model's expressivity,\n    but also increases the size of the generated LoRA model.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines._experimental.sd_dpo_lora.config.SdDirectPreferenceOptimizationLoraConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If true, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    See also 'validate_every_n_epochs'.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    data_loader: ImagePairPreferenceSDDataLoaderConfig\n\n    initial_lora: str | None = None\n    \"\"\"The LoRA checkpoint directory to initialize the LoRA weights from.\n\n    If set, the following configuration parameters are ignored:\n    - `train_unet`: The UNet will be trained if it is present in `initial_lora`.\n    - `train_text_encoder`: The text encoder will be trained if it is present in `initial_lora`.\n    - `lora_rank_dim`: The LoRA rank dimension from `initial_lora` will be used.\n\n    Currently only LoRA checkpoints in the internal `invoke-training` PEFT format are supported (i.e. checkpoints\n    generated by an `invoke-training` training pipeline).\n    \"\"\"\n\n    beta: float = 5000.0\n    \"\"\"The beta parameter, as defined in (https://arxiv.org/pdf/2311.12908.pdf). 
Larger beta values increase the\n    KL-Divergence penalty, discouraging divergence from the reference model weights.\n\n    Typical values for `beta` are in the range [1000.0, 10000.0].\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py",
    "content": "import copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Literal\n\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import (\n    build_image_pair_preference_sd_dataloader,\n)\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    TEXT_ENCODER_TARGET_MODULES,\n    UNET_TARGET_MODULES,\n    load_sd_peft_checkpoint,\n    save_sd_kohya_checkpoint,\n    save_sd_peft_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd\nfrom invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig\nfrom invoke_training.pipelines.callbacks import PipelineCallbacks\nfrom invoke_training.pipelines.stable_diffusion.lora.train import cache_text_encoder_outputs\n\n\ndef _save_sd_lora_checkpoint(\n    epoch: int,\n    step: int,\n    unet: peft.PeftModel | None,\n    text_encoder: peft.PeftModel | None,\n    logger: logging.Logger,\n    
checkpoint_tracker: CheckpointTracker,\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"],\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    if lora_checkpoint_format == \"invoke_peft\":\n        save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)\n    elif lora_checkpoint_format == \"kohya\":\n        save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)\n    else:\n        raise ValueError(f\"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.\")\n\n\ndef train_forward_dpo(  # noqa: C901\n    config: SdDirectPreferenceOptimizationLoraConfig,\n    data_batch: dict,\n    vae: AutoencoderKL,\n    noise_scheduler: DDPMScheduler,\n    tokenizer: CLIPTokenizer,\n    text_encoder: CLIPTextModel,\n    unet: UNet2DConditionModel,\n    ref_text_encoder: CLIPTextModel,\n    ref_unet: UNet2DConditionModel,\n    weight_dtype: torch.dtype,\n) -> torch.Tensor:\n    \"\"\"Run the forward training pass for a single data_batch.\n\n    This forward pass is based on 'Diffusion Model Alignment Using Direct Preference Optimization'\n    (https://arxiv.org/pdf/2311.12908.pdf). 
See the \"Pseudocode for Training Objective\" Appendix section for a helpful\n    reference.\n\n    Returns:\n        torch.Tensor: Loss\n    \"\"\"\n    batch_size = data_batch[\"image_0\"].shape[0]\n\n    # Concatenate image_0 and image_1 images into a single image batch.\n    images = torch.concat((data_batch[\"image_0\"], data_batch[\"image_1\"]))\n\n    # Re-order images so that the 'images' batch contains all winner images followed by all loser images.\n    w_indices = []\n    l_indices = []\n    prefer_0 = data_batch[\"prefer_0\"]\n    prefer_1 = data_batch[\"prefer_1\"]\n    for i in range(batch_size):\n        if prefer_0[i] and not prefer_1[i]:\n            w_indices.append(i)\n            l_indices.append(i + batch_size)\n        elif not prefer_0[i] and prefer_1[i]:\n            w_indices.append(i + batch_size)\n            l_indices.append(i)\n        else:\n            raise ValueError(f\"Encountered image pair with prefer_0={prefer_0[i]} and prefer_1={prefer_1[i]}.\")\n    images = images[w_indices + l_indices]\n\n    # Update batch_size in case image pairs were filtered due to no-preference.\n    batch_size = images.shape[0] // 2\n\n    # Convert images to latent space.\n    # The VAE output may have been cached and included in the data_batch. 
If not, we calculate it here.\n    latents = data_batch.get(\"vae_output\", None)\n    if latents is None:\n        latents = vae.encode(images.to(dtype=weight_dtype)).latent_dist.sample()\n        latents = latents * vae.config.scaling_factor\n\n    # Sample noise that we'll add to the latents.\n    # We want to use the same noise for the winning and losing example in each pair, so we generate noise for the\n    # winning latents and then repeat it.\n    noise = torch.randn_like(latents[:batch_size])\n    noise = noise.repeat((2, 1, 1, 1))\n\n    # Sample a random timestep for each image **pair**.\n    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device)\n    timesteps = timesteps.repeat((2,)).long()\n\n    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward\n    # diffusion process).\n    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n    # Get the text embedding for conditioning (for both the text_encoder and ref_text_encoder).\n    # The text_encoder_output may have been cached and included in the data_batch. 
If not, we calculate it here.\n    encoder_hidden_states = data_batch.get(\"text_encoder_output\", None)\n    if encoder_hidden_states is None:\n        caption_token_ids = tokenize_captions(tokenizer, data_batch[\"caption\"]).to(text_encoder.device)\n        encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)\n        ref_encoder_hidden_states = ref_text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)\n    encoder_hidden_states = encoder_hidden_states.repeat((2, 1, 1))\n    ref_encoder_hidden_states = ref_encoder_hidden_states.repeat((2, 1, 1))\n\n    # Get the target for loss depending on the prediction type.\n    if config.prediction_type is not None:\n        # Set the prediction_type of scheduler if it's defined in config.\n        noise_scheduler.register_to_config(prediction_type=config.prediction_type)\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        target = noise\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        target = noise_scheduler.get_velocity(latents, noise, timesteps)\n    else:\n        raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n    # Predict the noise residual.\n    ref_model_pred: torch.Tensor = ref_unet(noisy_latents, timesteps, ref_encoder_hidden_states).sample\n    model_pred: torch.Tensor = unet(noisy_latents, timesteps, encoder_hidden_states).sample\n\n    if \"loss_weight\" in data_batch:\n        raise NotImplementedError(\"loss_weight is not yet supported.\")\n\n    target = target.float()\n    w_target = target[:batch_size]\n    l_target = target[batch_size:]\n    model_w_pred = model_pred[:batch_size]\n    model_l_pred = model_pred[batch_size:]\n    ref_w_pred = ref_model_pred[:batch_size]\n    ref_l_pred = ref_model_pred[batch_size:]\n\n    # The pseudo-code from the paper uses `.norm().pow(2)` to calculate the errors. 
We take the mean over all pixels\n    # rather than the sum over all pixels. This helps keep the learning rate stable across different image\n    # resolutions. It also means that the recommended settings for beta from the paper are not correct.\n    # > model_w_err = (model_w_pred - target).norm().pow(2)\n    # > model_l_err = (model_l_pred - target).norm().pow(2)\n    # > ref_w_err = (ref_w_pred - target).norm().pow(2)\n    # > ref_l_err = (ref_l_pred - target).norm().pow(2)\n    model_w_err = torch.nn.functional.mse_loss(model_w_pred, w_target)\n    model_l_err = torch.nn.functional.mse_loss(model_l_pred, l_target)\n    ref_w_err = torch.nn.functional.mse_loss(ref_w_pred, w_target)\n    ref_l_err = torch.nn.functional.mse_loss(ref_l_pred, l_target)\n\n    w_diff = model_w_err - ref_w_err\n    l_diff = model_l_err - ref_l_err\n    inside_term = -1 * config.beta * (w_diff - l_diff)\n    loss = -1 * torch.nn.functional.logsigmoid(inside_term)\n    return loss\n\n\ndef train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    if callbacks:\n        raise ValueError(f\"This pipeline does not support callbacks, but {len(callbacks)} were provided.\")\n\n    # Give a clear error message if an unsupported base model was chosen.\n    # TODO(ryan): Update this check to work with single-file SD checkpoints.\n    # check_base_model_version(\n    #     {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},\n    #     config.model,\n    #     local_files_only=False,\n    # )\n\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = 
initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting LoRA Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        base_embeddings=config.base_embeddings,\n        dtype=weight_dtype,\n    )\n    ref_text_encoder = copy.deepcopy(text_encoder)\n    ref_unet = copy.deepcopy(unet)\n\n    if config.xformers:\n        import_xformers()\n\n        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers\n        # will fail because Q, K, V have different dtypes.\n        unet.enable_xformers_memory_efficient_attention()\n        ref_unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare text encoder output cache.\n    text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. 
Currently, there\n        # are a number of configurations that would cause variation in the text encoder outputs and should not be used\n        # with caching.\n        # TODO(ryand): This check does not make sense when config.initial_lora is set.\n        if config.train_text_encoder:\n            raise ValueError(\"'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.\")\n\n        # We use a temporary directory for the cache. The directory will automatically be cleaned up when\n        # tmp_text_encoder_output_cache_dir is destroyed.\n        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()\n        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').\")\n            text_encoder.to(accelerator.device, dtype=weight_dtype)\n            cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder)\n        # Move the text_encoder back to the CPU, because it is not needed for training.\n        text_encoder.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n        ref_text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        raise NotImplementedError(\"VAE caching is not implemented for Diffusion-DPO training yet.\")\n        # # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # # tmp_vae_output_cache_dir is destroyed.\n        # tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        # vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        # if accelerator.is_local_main_process:\n        #     # Only the main process should populate the cache.\n        #     logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n        #     vae.to(accelerator.device, dtype=weight_dtype)\n\n        #     data_loader = build_data_loader(\n        #         data_loader_config=config.data_loader,\n        #         batch_size=config.train_batch_size,\n        #         shuffle=False,\n        #         sequential_batching=True,\n        #     )\n        #     cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # # Move the VAE back to the CPU, because it is not needed for training.\n        # vae.to(\"cpu\")\n        # accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n    ref_unet.to(accelerator.device, dtype=weight_dtype)\n\n    # Add LoRA layers to the models being trained.\n    trainable_param_groups = []\n    all_trainable_models: list[peft.PeftModel] = []\n\n    # Add LoRA layers to the model.\n    trainable_param_groups = []\n    if config.initial_lora is not None:\n        unet, text_encoder = load_sd_peft_checkpoint(\n            checkpoint_dir=config.initial_lora, unet=unet, text_encoder=text_encoder, is_trainable=True\n        )\n        ref_unet, ref_text_encoder = load_sd_peft_checkpoint(\n            checkpoint_dir=config.initial_lora, unet=ref_unet, text_encoder=ref_text_encoder, is_trainable=False\n        )\n    else:\n        if config.train_unet:\n            unet_lora_config = peft.LoraConfig(\n                r=config.lora_rank_dim,\n                # TODO(ryand): Diffusers uses 
lora_alpha=config.lora_rank_dim. Is that preferred?\n                lora_alpha=1.0,\n                target_modules=UNET_TARGET_MODULES,\n            )\n            unet = peft.get_peft_model(unet, unet_lora_config)\n\n        if config.train_text_encoder:\n            text_encoder_lora_config = peft.LoraConfig(\n                r=config.lora_rank_dim,\n                lora_alpha=1.0,\n                # init_lora_weights=\"gaussian\",\n                target_modules=TEXT_ENCODER_TARGET_MODULES,\n            )\n            text_encoder = peft.get_peft_model(text_encoder, text_encoder_lora_config)\n\n    def prep_peft_model(model, lr: float | None = None):\n        if not isinstance(model, peft.PeftModel):\n            return False\n\n        model.print_trainable_parameters()\n\n        # Populate `trainable_param_groups`, to be passed to the optimizer.\n        param_group = {\"params\": list(filter(lambda p: p.requires_grad, model.parameters()))}\n        if lr is not None:\n            param_group[\"lr\"] = lr\n        trainable_param_groups.append(param_group)\n\n        # Populate all_trainable_models.\n        all_trainable_models.append(model)\n\n        model.train()\n\n        return True\n\n    training_unet = prep_peft_model(unet, config.unet_learning_rate)\n    training_text_encoder = prep_peft_model(text_encoder, config.text_encoder_learning_rate)\n\n    # If mixed_precision is enabled, cast all trainable params to float32.\n    if config.mixed_precision != \"no\":\n        for trainable_model in all_trainable_models:\n            for param in trainable_model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # 
At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n        if training_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n            # The text encoder must be in train() mode for gradient checkpointing to take effect. This should\n            # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear\n            # that this is required.\n            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text\n            # encoders in train mode does not change their forward behavior.\n            text_encoder.train()\n\n            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder\n            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of\n            # trainable_param_groups has already been populated - the embeddings will not be trained.\n            text_encoder.text_model.embeddings.requires_grad_(True)\n\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n\n    data_loader = build_image_pair_preference_sd_dataloader(\n        config=config.data_loader,\n        batch_size=config.train_batch_size,\n        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,\n        text_encoder_cache_field_to_output_field={\"text_encoder_output\": \"text_encoder_output\"},\n        vae_output_cache_dir=vae_output_cache_dir_name,\n        shuffle=True,\n    )\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=config.max_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        UNet2DConditionModel | peft.PeftModel,\n        CLIPTextModel | peft.PeftModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        unet,\n        text_encoder,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[True, not config.cache_text_encoder_outputs, True, True, True],\n    )\n    unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result\n\n    # Calculate the number of epochs and total training steps. A \"step\" represents a single weight update operation\n    # (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) 
is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_epochs = math.ceil(config.max_train_steps / num_steps_per_epoch)\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"lora_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        extension=\".safetensors\" if config.lora_checkpoint_format == \"kohya\" else None,\n        max_checkpoints=config.max_checkpoints,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {config.max_train_steps}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = first_epoch\n\n    progress_bar = tqdm(\n        range(global_step, config.max_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    for epoch in range(first_epoch, num_train_epochs):\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(unet, text_encoder):\n                loss = train_forward_dpo(\n                    config=config,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer=tokenizer,\n                    text_encoder=text_encoder,\n                    unet=unet,\n                    ref_text_encoder=ref_text_encoder,\n                    ref_unet=ref_unet,\n                    weight_dtype=weight_dtype,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                
optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n                if training_unet:\n                    # When training the UNet, it will always be the first parameter group.\n                    log[\"lr/unet\"] = float(lrs[0])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/unet\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n                if training_text_encoder:\n                    # When training the text encoder, it will always be the last parameter group.\n                    log[\"lr/text_encoder\"] = float(lrs[-1])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/text_encoder\"] = optimizer.param_groups[-1][\"d\"] * optimizer.param_groups[-1][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    accelerator.wait_for_everyone()\n                    if accelerator.is_main_process:\n                        _save_sd_lora_checkpoint(\n                            epoch=completed_epochs,\n                            step=global_step,\n                            unet=accelerator.unwrap_model(unet) if training_unet else None,\n                            text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,\n        
                    logger=logger,\n                            checkpoint_tracker=checkpoint_tracker,\n                            lora_checkpoint_format=config.lora_checkpoint_format,\n                        )\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= config.max_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            if accelerator.is_main_process:\n                accelerator.wait_for_everyone()\n                _save_sd_lora_checkpoint(\n                    epoch=completed_epochs,\n                    step=global_step,\n                    unet=accelerator.unwrap_model(unet) if training_unet else None,\n                    text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,\n                    logger=logger,\n                    checkpoint_tracker=checkpoint_tracker,\n                    lora_checkpoint_format=config.lora_checkpoint_format,\n                )\n\n        # Generate validation images every n epochs.\n        if len(config.validation_prompts) > 0 and completed_epochs % config.validate_every_n_epochs == 0:\n            if accelerator.is_main_process:\n                generate_validation_images_sd(\n                    epoch=completed_epochs,\n                    step=global_step,\n                    out_dir=out_dir,\n                    accelerator=accelerator,\n                    vae=vae,\n                    text_encoder=text_encoder,\n                    tokenizer=tokenizer,\n                    noise_scheduler=noise_scheduler,\n                    unet=unet,\n                    config=config,\n                    logger=logger,\n                )\n\n    
accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/callbacks.py",
    "content": "from abc import ABC\nfrom enum import Enum\n\n\nclass ModelType(Enum):\n    # At first glance, it feels like these model types should be further broken down into separate enums (e.g.\n    # base_model, model_type, checkpoint_format). But, I haven't yet come up with a taxonomy that feels sufficiently\n    # future-proof. So, for now, there is one enum for each file type that invoke-training can produce.\n\n    # A Flux LoRA model in PEFT format.\n    FLUX_LORA_PEFT = \"FLUX_LORA_PEFT\"\n    # A Flux LoRA model in Kohya format.\n    FLUX_LORA_KOHYA = \"FLUX_LORA_KOHYA\"\n\n    # A Stable Diffusion 1.x LoRA model in Kohya format.\n    SD1_LORA_KOHYA = \"SD1_LORA_KOHYA\"\n    # A Stable Diffusion 1.x LoRA model in PEFT format.\n    SD1_LORA_PEFT = \"SD1_LORA_PEFT\"\n    # A Stable Diffusion XL LoRA model in Kohya format.\n    SDXL_LORA_KOHYA = \"SDXL_LORA_KOHYA\"\n    # A Stable Diffusion XL LoRA model in PEFT format.\n    SDXL_LORA_PEFT = \"SDXL_LORA_PEFT\"\n\n    # A Stable Diffusion 1.x Textual Inversion model.\n    SD1_TEXTUAL_INVERSION = \"SD1_TEXTUAL_INVERSION\"\n    # A Stable Diffusion XL Textual Inversion model.\n    SDXL_TEXTUAL_INVERSION = \"SDXL_TEXTUAL_INVERSION\"\n\n    # A Stable Diffusion 1.x UNet checkpoint in diffusers format.\n    SD1_UNET_DIFFUSERS = \"SD1_UNET_DIFFUSERS\"\n    # A Stable Diffusion XL UNet checkpoint in diffusers format.\n    SDXL_UNET_DIFFUSERS = \"SDXL_UNET_DIFFUSERS\"\n    # A full Stable Diffusion XL checkpoint in diffusers format.\n    SDXL_FULL_DIFFUSERS = \"SDXL_FULL_DIFFUSERS\"\n\n\nclass ModelCheckpoint:\n    \"\"\"A single model checkpoint.\"\"\"\n\n    def __init__(self, file_path: str, model_type: ModelType):\n        self.file_path = file_path\n        self.model_type = model_type\n\n\nclass TrainingCheckpoint:\n    \"\"\"A training checkpoint. 
May contain multiple model checkpoints if multiple models are being trained\n    simultaneously.\n    \"\"\"\n\n    def __init__(self, models: list[ModelCheckpoint], epoch: int, step: int):\n        self.models = models\n        self.epoch = epoch\n        self.step = step\n\n\nclass ValidationImage:\n    def __init__(self, file_path: str, prompt: str, image_idx: int):\n        \"\"\"A single validation image.\n\n        Args:\n            file_path (str): Path to the image file.\n            prompt (str): The prompt used to generate the image.\n            image_idx (int): The index of this image in the current validation set (i.e. in the set of images generated\n                with the same prompt at the same validation point).\n        \"\"\"\n        self.file_path = file_path\n        self.prompt = prompt\n        self.image_idx = image_idx\n\n\nclass ValidationImages:\n    def __init__(self, images: list[ValidationImage], epoch: int, step: int):\n        \"\"\"A collection of validation images.\n\n        Args:\n            images (list[ValidationImage]): The validation images.\n            epoch (int): The last completed epoch at the time that these images were generated.\n            step (int): The last completed training step at the time that these images were generated.\n        \"\"\"\n        self.images = images\n        self.epoch = epoch\n        self.step = step\n\n\nclass PipelineCallbacks(ABC):\n    def on_save_checkpoint(self, checkpoint: TrainingCheckpoint):\n        pass\n\n    def on_save_validation_images(self, images: ValidationImages):\n        pass\n"
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/config.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field\n\nfrom invoke_training._shared.flux.lora_checkpoint_utils import (\n    FLUX_TRANSFORMER_TARGET_MODULES,\n    TEXT_ENCODER_TARGET_MODULES,\n)\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import (\n    AdamOptimizerConfig,\n    ProdigyOptimizerConfig,\n)\n\n\nclass FluxLoraConfig(BasePipelineConfig):\n    type: Literal[\"FLUX_LORA\"] = \"FLUX_LORA\"\n\n    model: str = \"black-forest-labs/FLUX.1-dev\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single Flux.1-dev checkpoint\n    file. (E.g. 'black-forest-labs/FLUX.1-dev', '/path/to/flux.1-dev.safetensors', etc. )\n    \"\"\"\n\n    transformer_path: str | None = None\n    \"\"\"Path to the custom transformer .safetensors file. If not provided, the default black-forest-labs/FLUX.1-dev\n    transformer will be used.\n    \"\"\"\n\n    text_encoder_1_path: str | None = None\n    \"\"\"Path to the custom CLIP text encoder .safetensors file. If not provided, the default openai/clip-vit-base-patch32\n    text encoder will be used.\n    \"\"\"\n\n    text_encoder_2_path: str | None = None\n    \"\"\"Path to the custom T5 text encoder .safetensors file. If not provided, the default google/t5-v1_1-xl text encoder\n     will be used.\n     \"\"\"\n\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"kohya\"\n    \"\"\"The format of the LoRA checkpoint to save. 
Choose between `invoke_peft` or `kohya`.\"\"\"\n\n    train_transformer: bool = True\n    \"\"\"Whether to add LoRA layers to the FluxTransformer2DModel and train it.\n    \"\"\"\n\n    train_text_encoder: bool = False\n    \"\"\"Whether to add LoRA layers to the text encoder and train it.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    text_encoder_learning_rate: float | None = 1e-4\n    \"\"\"The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning\n    rate. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    transformer_learning_rate: float | None = 4e-4\n    \"\"\"The learning rate to use for the transformer model. If set, this overrides the optimizer's default learning\n    rate. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant_with_warmup\"\n\n    lr_warmup_steps: int = 10\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = None\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    lora_rank_dim: int = 4\n    \"\"\"The rank dimension to use for the LoRA layers. 
Increasing the rank dimension increases the model's expressivity,\n    but also increases the size of the generated LoRA model.\n    \"\"\"\n\n    flux_lora_target_modules: list[str] = FLUX_TRANSFORMER_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the FluxTransformer2DModel. The default list will produce a\n    highly expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    flux_lora_target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the CLIP text encoder. The default list will produce a\n    highly expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    text_encoder_lora_target_modules = [\"fc1\", \"fc2\", \"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. 
This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"float16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.flux.lora.config.FluxLoraConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    See also 'validate_every_n_epochs'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 1\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: Annotated[Union[ImageCaptionFluxDataLoaderConfig], Field(discriminator=\"type\")]\n\n    timestep_sampler: Literal[\"shift\", \"uniform\"] = \"shift\"\n    \"\"\"The timestep sampler to use. Choose between 'shift' or 'uniform'.\"\"\"\n\n    discrete_flow_shift: float = 3.0\n    \"\"\"The shift parameter for the discrete flow. Only used if `timestep_sampler == \"shift\"`.\n    \"\"\"\n\n    sigmoid_scale: float = 1.0\n    \"\"\"The scale parameter for the sigmoid function. Only used if `timestep_sampler == \"shift\"`.\n    \"\"\"\n\n    lora_scale: float | None = 1.0\n    \"\"\"The scale parameter for the LoRA layers. 
This value is passed as `lora_scale` when encoding prompts.\n    \"\"\"\n\n    guidance_scale: float = 1.0\n    \"\"\"The guidance scale for the Flux model.\n    \"\"\"\n\n    train_transformer: bool = True\n    \"\"\"Whether to train the Flux transformer (FluxTransformer2DModel) model.\n    \"\"\"\n\n    clip_tokenizer_max_length: int = 77\n    \"\"\"The maximum token length used for the CLIP tokenizer. The CLIP tokenizer supports at most 77 tokens.\n    \"\"\"\n\n    t5_tokenizer_max_length: int = 512\n    \"\"\"The maximum token length used for the T5 tokenizer. The T5 tokenizer supports at most 512 tokens.\n    \"\"\"\n"
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/train.py",
    "content": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Literal, Optional, Union\n\nimport numpy as np\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.pipelines.flux.pipeline_flux import FluxPipeline\nfrom peft import PeftModel\nfrom PIL import Image\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.data.data_loaders.image_caption_flux_dataloader import build_image_caption_flux_dataloader\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training._shared.flux.encoding_utils import encode_prompt\nfrom invoke_training._shared.flux.lora_checkpoint_utils import (\n    save_flux_kohya_checkpoint,\n    save_flux_peft_checkpoint,\n)\nfrom invoke_training._shared.flux.model_loading_utils import load_models_flux\nfrom invoke_training._shared.flux.validation import generate_validation_images_flux\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions\nfrom invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\n\n\ndef 
_save_flux_lora_checkpoint(\n    epoch: int,\n    step: int,\n    transformer: peft.PeftModel | None,\n    text_encoder_1: CLIPTextModel | None,\n    text_encoder_2: T5EncoderModel | None,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    callbacks: list[PipelineCallbacks] | None,\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"invoke_peft\",\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    if lora_checkpoint_format == \"invoke_peft\":\n        model_type = ModelType.FLUX_LORA_PEFT\n        save_flux_peft_checkpoint(\n            Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2\n        )\n    elif lora_checkpoint_format == \"kohya\":\n        model_type = ModelType.FLUX_LORA_KOHYA\n        save_flux_kohya_checkpoint(\n            Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2\n        )\n    else:\n        raise ValueError(f\"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.\")\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step\n                )\n            )\n\n\ndef _build_data_loader(\n    data_loader_config: Union[ImageCaptionSDDataLoaderConfig],\n    batch_size: int,\n    use_masks: bool = False,\n    text_encoder_output_cache_dir: Optional[str] = None,\n    vae_output_cache_dir: Optional[str] = None,\n    shuffle: bool = True,\n    sequential_batching: bool = False,\n) -> DataLoader:\n    if data_loader_config.type == \"IMAGE_CAPTION_FLUX_DATA_LOADER\":\n        return 
build_image_caption_flux_dataloader(\n            config=data_loader_config,\n            batch_size=batch_size,\n            use_masks=use_masks,\n            text_encoder_output_cache_dir=text_encoder_output_cache_dir,\n            text_encoder_cache_field_to_output_field={\"text_encoder_output\": \"text_encoder_output\"},\n            vae_output_cache_dir=vae_output_cache_dir,\n            shuffle=shuffle,\n        )\n    else:\n        raise ValueError(f\"Unsupported data loader config type: '{data_loader_config.type}'.\")\n\n\ndef cache_text_encoder_outputs(\n    cache_dir: str, config: FluxLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel\n):\n    \"\"\"Run the text encoder on all captions in the dataset and cache the results to disk.\n\n    Args:\n        cache_dir (str): The directory where the results will be cached.\n        config (FluxLoraConfig): Training config.\n        tokenizer (CLIPTokenizer): The tokenizer.\n        text_encoder (CLIPTextModel): The text_encoder.\n    \"\"\"\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        shuffle=False,\n        sequential_batching=True,\n    )\n\n    cache = TensorDiskCache(cache_dir)\n\n    for data_batch in tqdm(data_loader):\n        caption_token_ids = tokenize_captions(tokenizer, data_batch[\"caption\"]).to(text_encoder.device)\n        text_encoder_output_batch = text_encoder(caption_token_ids)[0]\n        # Split batch before caching.\n        for i in range(len(data_batch[\"id\"])):\n            cache.save(data_batch[\"id\"][i], {\"text_encoder_output\": text_encoder_output_batch[i]})\n\n\ndef cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL):\n    \"\"\"Run the VAE on all images in the dataset and cache the results to disk.\"\"\"\n    cache = TensorDiskCache(cache_dir)\n\n    for data_batch in tqdm(data_loader):\n        latents = 
vae.encode(data_batch[\"image\"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()\n        latents = latents * vae.config.scaling_factor\n        # Split batch before caching.\n        for i in range(len(data_batch[\"id\"])):\n            data = {\n                \"vae_output\": latents[i],\n                \"original_size_hw\": data_batch[\"original_size_hw\"][i],\n                \"crop_top_left_yx\": data_batch[\"crop_top_left_yx\"][i],\n            }\n            if \"mask\" in data_batch:\n                data[\"mask\"] = data_batch[\"mask\"][i]\n            cache.save(data_batch[\"id\"][i], data)\n\n\ndef get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):\n    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)\n    schedule_timesteps = noise_scheduler.timesteps.to(device)\n    timesteps = timesteps.to(device)\n    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n    sigma = sigmas[step_indices].flatten()\n    while len(sigma.shape) < n_dim:\n        sigma = sigma.unsqueeze(-1)\n    return sigma\n\n\ndef get_noisy_latents(noise_scheduler: FlowMatchEulerDiscreteScheduler, latents: torch.Tensor, config: FluxLoraConfig):\n    \"\"\"\n    Generate random noise. 
Sample a random timestep from the distribution chosen by the config.\n    Linearly interpolate between the latents and the noise based on timestep.\n    See Section 3.1 of https://arxiv.org/pdf/2403.03206v1 for timestep sampling.\n\n    Args:\n        noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler.\n        latents (torch.Tensor): The latents.\n        config (FluxLoraConfig): The config.\n\n    Returns:\n        torch.Tensor: The noisy latents.\n\n    \"\"\"\n\n    batch_size = latents.shape[0]\n    dtype = latents.dtype\n    device = latents.device\n    noise = torch.randn_like(latents)\n\n    if config.timestep_sampler == \"shift\":\n        shift = config.discrete_flow_shift\n        sigmas = torch.randn(batch_size, device=device)\n        sigmas = sigmas * config.sigmoid_scale  # larger scale for more uniform sampling\n        sigmas = sigmas.sigmoid()\n        sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)\n        timesteps = sigmas * noise_scheduler.config.num_train_timesteps\n    else:\n        u = torch.rand(size=(batch_size,), device=\"cpu\")\n        indices = (u * noise_scheduler.config.num_train_timesteps).long()\n        timesteps = noise_scheduler.timesteps[indices].to(device=device)\n        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)\n\n    sigmas = sigmas.view(-1, 1, 1, 1)\n\n    # Linearly interpolate between the latents and the noise.\n    noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise\n    return noisy_model_input.to(dtype), noise.to(dtype), timesteps.to(dtype), sigmas.to(dtype)\n\n\ndef decode_latents(vae: AutoencoderKL, latents: torch.Tensor):\n    latents = latents / vae.config.scaling_factor\n    image = vae.decode(latents).sample\n\n    # tensor to image\n    image = image.cpu().numpy()\n    image = (image * 255).astype(np.uint8)\n    image = Image.fromarray(image)\n\n    image.save(\"image.png\")\n    return image\n\n\ndef train_forward(  # noqa: 
C901\n    config: FluxLoraConfig,\n    data_batch: dict,\n    vae: AutoencoderKL,\n    noise_scheduler: FlowMatchEulerDiscreteScheduler,\n    tokenizer_1: CLIPTokenizer,\n    tokenizer_2: T5Tokenizer,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: T5EncoderModel,\n    transformer: FluxTransformer2DModel | PeftModel,\n    weight_dtype: torch.dtype,\n    use_masks: bool = False,\n    min_snr_gamma: float | None = None,\n    logger: logging.Logger = None,\n) -> torch.Tensor:\n    \"\"\"Run the forward training pass for a single data_batch.\n\n    Returns:\n        torch.Tensor: Loss\n    \"\"\"\n    # Convert images to latent space.\n    # The VAE output may have been cached and included in the data_batch. If not, we calculate it here.\n    latents = data_batch.get(\"vae_output\", None)\n    if latents is None:\n        # Cast input image to same dtype as VAE\n        image = data_batch[\"image\"].to(device=vae.device, dtype=vae.dtype)\n        latents = vae.encode(image).latent_dist.sample()\n        batch_size, num_channels, height, width = latents.shape\n        latents = latents * vae.config.scaling_factor\n        latents = FluxPipeline._pack_latents(latents, batch_size, num_channels, height, width)\n    else:\n        batch_size, num_channels, height, width = latents.shape\n    # Sample noise that we'll add to the latents.\n    latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n        batch_size, height // 2, width // 2, latents.device, latents.dtype\n    )\n\n    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward\n    # diffusion process).\n    noisy_latents, noise, timesteps, sigmas = get_noisy_latents(noise_scheduler, latents, config)\n\n    # Get the text embedding for conditioning.\n    # The text encoder output may have been cached and included in the data_batch. 
If not, we calculate it here.\n    if \"prompt_embeds\" in data_batch:\n        prompt_embeds = data_batch[\"prompt_embeds\"]\n        pooled_prompt_embeds = data_batch[\"pooled_prompt_embeds\"]\n    else:\n        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n            prompt=data_batch[\"caption\"],\n            prompt_2=data_batch.get(\"caption_2\", None),\n            clip_tokenizer=tokenizer_1,\n            t5_tokenizer=tokenizer_2,\n            clip_text_encoder=text_encoder_1,\n            t5_text_encoder=text_encoder_2,\n            device=latents.device,\n            num_images_per_prompt=1,\n            lora_scale=config.lora_scale,\n            clip_tokenizer_max_length=config.clip_tokenizer_max_length,\n            t5_tokenizer_max_length=config.t5_tokenizer_max_length,\n            logger=logger,\n        )\n\n    guidance = torch.full((batch_size,), float(config.guidance_scale), device=latents.device)\n    model_pred = transformer(\n        hidden_states=noisy_latents[0],\n        timestep=timesteps / 1000,\n        pooled_projections=pooled_prompt_embeds,\n        encoder_hidden_states=prompt_embeds,\n        guidance=guidance,\n        txt_ids=text_ids,\n        img_ids=latent_image_ids,\n        return_dict=False,\n    )[0]\n    ### Flow matching loss\n    # See here for more discussion:https://discuss.huggingface.co/t/meaning-of-vector-fields-in-flux-and-sd3-loss-function/106601\n    target = noise - latents\n\n    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n    loss = loss.mean(dim=list(range(1, len(loss.shape))))\n    return loss.mean()\n\n\ndef train(config: FluxLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = 
initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting LoRA Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer = load_models_flux(\n        model_name_or_path=config.model,\n        transformer_path=config.transformer_path,\n        text_encoder_1_path=config.text_encoder_1_path,\n        text_encoder_2_path=config.text_encoder_2_path,\n        dtype=weight_dtype,\n        logger=logger,\n    )\n\n    # Prepare text encoder output cache.\n    text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there\n        # are a number of configurations that would cause variation in the text encoder outputs and should not be used\n        # with caching.\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_text_encoder_output_cache_dir is destroyed.\n        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()\n        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').\")\n            text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n            text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n            # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another\n            # pipeline.\n            cache_text_encoder_outputs(\n                text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2\n            )\n        # Move the text_encoders back to the CPU, because they are not needed for training.\n        text_encoder_1.to(\"cpu\")\n        text_encoder_2.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    # vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not config.data_loader.center_crop:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            data_loader = _build_data_loader(\n                data_loader_config=config.data_loader,\n                batch_size=config.train_batch_size,\n                shuffle=False,\n                sequential_batching=True,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    transformer.to(accelerator.device, dtype=weight_dtype)\n\n    # Add LoRA layers to the models being trained.\n    trainable_param_groups = []\n    all_trainable_models: list[peft.PeftModel] = []\n\n    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:\n        peft_model = peft.get_peft_model(model, lora_config)\n        peft_model.print_trainable_parameters()\n\n        # Populate `trainable_param_groups`, to be passed to the optimizer.\n        param_group = {\"params\": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}\n        if lr is not None:\n            param_group[\"lr\"] = lr\n        trainable_param_groups.append(param_group)\n\n        # Populate all_trainable_models.\n        all_trainable_models.append(peft_model)\n\n        peft_model.train()\n\n        return peft_model\n\n    # Add LoRA layers to the model.\n    if config.train_transformer:\n        
transformer_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?\n            lora_alpha=1.0,\n            target_modules=config.flux_lora_target_modules,\n        )\n        transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)\n\n    if config.train_text_encoder:\n        text_encoder_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            lora_alpha=1.0,\n            # init_lora_weights=\"gaussian\",\n            target_modules=config.text_encoder_lora_target_modules,\n        )\n\n        text_encoder_1 = inject_lora_layers(\n            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate\n        )\n\n    # Enable gradient checkpointing.\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        transformer.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        transformer.train()\n        if config.train_text_encoder:\n            text_encoder_1.gradient_checkpointing_enable()\n            # The text encoders must be in train() mode for gradient checkpointing to take effect. 
This should\n            # already be the case, since we are training the text_encoders, but we do it explicitly to make it clear\n            # that this is required.\n            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text\n            # encoders in train mode does not change their forward behavior.\n            text_encoder_1.train()\n            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder\n            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of\n            # trainable_param_groups has already been populated - the embeddings will not be trained.\n            text_encoder_1.text_model.embeddings.requires_grad_(True)\n\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        # text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,\n        # vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) 
is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        FluxTransformer2DModel,\n        CLIPTextModel,\n        T5EncoderModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        transformer,\n        text_encoder_1,\n        text_encoder_2,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[\n            True,\n            not config.cache_text_encoder_outputs,\n            not config.cache_text_encoder_outputs,\n            True,\n            
True,\n            True,\n        ],\n    )\n    transformer, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"lora_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        max_checkpoints=config.max_checkpoints,\n        extension=\".safetensors\" if config.lora_checkpoint_format == \"kohya\" else None,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_flux_lora_checkpoint(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                transformer=transformer if config.train_transformer else None,\n                text_encoder_1=text_encoder_1 if config.train_text_encoder else None,\n                text_encoder_2=text_encoder_2 if config.train_text_encoder else None,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                lora_checkpoint_format=config.lora_checkpoint_format,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_flux(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                transformer=transformer,\n                config=config,\n 
               logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            # (Pdb) data_batch['image'].shape\n            # torch.Size([4, 3, 512, 512])\n            with accelerator.accumulate(transformer, text_encoder_1, text_encoder_2):\n                loss = train_forward(\n                    config=config,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer_1=tokenizer_1,\n                    tokenizer_2=tokenizer_2,\n                    text_encoder_1=text_encoder_1,\n                    text_encoder_2=text_encoder_2,\n                    transformer=transformer,\n                    weight_dtype=weight_dtype,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if 
accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n                if config.train_transformer:\n                    # When training the UNet, it will always be the first parameter group.\n                    log[\"lr/transformer\"] = float(lrs[0])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/transformer\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n                if config.train_text_encoder:\n                    # When training the text encoder, it will always be the last parameter group.\n                    log[\"lr/text_encoder\"] = float(lrs[-1])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/text_encoder\"] = optimizer.param_groups[-1][\"d\"] * optimizer.param_groups[-1][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n           
 }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/invoke_train.py",
    "content": "import os\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.train import train as train_sd_ddpo_lora\nfrom invoke_training.pipelines.callbacks import PipelineCallbacks\nfrom invoke_training.pipelines.flux.lora.train import train as train_flux_lora\nfrom invoke_training.pipelines.stable_diffusion.lora.train import train as train_sd_lora\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.train import train as train_sd_ti\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.train import train as train_sdxl_finetune\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.train import train as train_sdxl_lora\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.train import (\n    train as train_sdxl_lora_and_ti,\n)\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import train as train_sdxl_ti\n\n\ndef train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | None = None):\n    \"\"\"This is the main entry point for all training pipelines.\"\"\"\n\n    # Fail early if invalid callback types are provided, rather than failing later when the callbacks are used.\n    for cb in callbacks or []:\n        assert isinstance(cb, PipelineCallbacks)\n\n    if config.type == \"FLUX_LORA\":\n        # Disable tokenizer parallelism to avoid issues with tokenization\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n        train_flux_lora(config, callbacks)\n    elif config.type == \"SD_LORA\":\n        train_sd_lora(config, callbacks)\n    elif config.type == \"SDXL_LORA\":\n        train_sdxl_lora(config, callbacks)\n    elif config.type == \"SD_TEXTUAL_INVERSION\":\n        train_sd_ti(config, callbacks)\n    elif config.type == \"SDXL_TEXTUAL_INVERSION\":\n        train_sdxl_ti(config, callbacks)\n    elif config.type == \"SDXL_LORA_AND_TEXTUAL_INVERSION\":\n        
train_sdxl_lora_and_ti(config, callbacks)\n    elif config.type == \"SDXL_FINETUNE\":\n        train_sdxl_finetune(config, callbacks)\n    elif config.type == \"SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA\":\n        print(f\"Running EXPERIMENTAL pipeline: '{config.type}'.\")\n        train_sd_ddpo_lora(config, callbacks)\n    else:\n        raise ValueError(f\"Unexpected pipeline type: '{config.type}'.\")\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/config.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    TEXT_ENCODER_TARGET_MODULES,\n    UNET_TARGET_MODULES,\n)\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdLoraConfig(BasePipelineConfig):\n    type: Literal[\"SD_LORA\"] = \"SD_LORA\"\n\n    model: str = \"runwayml/stable-diffusion-v1-5\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    # Note: Pydantic handles mutable default values well:\n    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values\n    base_embeddings: dict[str, str] = {}\n    \"\"\"A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model\n    before training.\n\n    Example:\n    ```\n    base_embeddings = {\n        \"bruce_the_gnome\": \"/path/to/bruce_the_gnome.safetensors\",\n    }\n    ```\n\n    Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the\n    dataset captions.\n\n    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they\n    are referenced in the dataset captions. 
The list of embeddings provided here should be the same list used at\n    generation time with the resultant LoRA model.\n    \"\"\"\n\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"kohya\"\n    \"\"\"The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`.\"\"\"\n\n    train_unet: bool = True\n    \"\"\"Whether to add LoRA layers to the UNet model and train it.\n    \"\"\"\n\n    train_text_encoder: bool = True\n    \"\"\"Whether to add LoRA layers to the text encoder and train it.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    text_encoder_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning\n    rate. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    unet_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.\n    Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. 
If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    lora_rank_dim: int = 4\n    \"\"\"The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,\n    but also increases the size of the generated LoRA model.\n    \"\"\"\n\n    # The default list of target modules is based on\n    # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87\n    unet_lora_target_modules: list[str] = UNET_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly\n    expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    unet_lora_target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a\n    highly expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    text_encoder_lora_target_modules = [\"fc1\", \"fc2\", \"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. 
See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If true, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    See also 'validate_every_n_epochs'.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: Annotated[\n        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator=\"type\")\n    ]\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/train.py",
    "content": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Literal, Optional, Union\n\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    save_sd_kohya_checkpoint,\n    save_sd_peft_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd\nfrom invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig\nfrom 
invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\n\n\ndef _save_sd_lora_checkpoint(\n    epoch: int,\n    step: int,\n    unet: peft.PeftModel | None,\n    text_encoder: peft.PeftModel | None,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"],\n    callbacks: list[PipelineCallbacks] | None,\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    if lora_checkpoint_format == \"invoke_peft\":\n        model_type = ModelType.SD1_LORA_PEFT\n        save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)\n    elif lora_checkpoint_format == \"kohya\":\n        model_type = ModelType.SD1_LORA_KOHYA\n        save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)\n    else:\n        raise ValueError(f\"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.\")\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step\n                )\n            )\n\n\ndef _build_data_loader(\n    data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig],\n    batch_size: int,\n    use_masks: bool = False,\n    text_encoder_output_cache_dir: Optional[str] = None,\n    vae_output_cache_dir: Optional[str] = None,\n    shuffle: bool = True,\n    sequential_batching: bool = False,\n) -> DataLoader:\n    if data_loader_config.type == \"IMAGE_CAPTION_SD_DATA_LOADER\":\n        return 
build_image_caption_sd_dataloader(\n            config=data_loader_config,\n            batch_size=batch_size,\n            use_masks=use_masks,\n            text_encoder_output_cache_dir=text_encoder_output_cache_dir,\n            text_encoder_cache_field_to_output_field={\"text_encoder_output\": \"text_encoder_output\"},\n            vae_output_cache_dir=vae_output_cache_dir,\n            shuffle=shuffle,\n        )\n    elif data_loader_config.type == \"DREAMBOOTH_SD_DATA_LOADER\":\n        if use_masks:\n            raise NotImplementedError(\"Masks are not yet supported for DreamBooth data loaders.\")\n        return build_dreambooth_sd_dataloader(\n            config=data_loader_config,\n            batch_size=batch_size,\n            text_encoder_output_cache_dir=text_encoder_output_cache_dir,\n            text_encoder_cache_field_to_output_field={\"text_encoder_output\": \"text_encoder_output\"},\n            vae_output_cache_dir=vae_output_cache_dir,\n            shuffle=shuffle,\n            sequential_batching=sequential_batching,\n        )\n    else:\n        raise ValueError(f\"Unsupported data loader config type: '{data_loader_config.type}'.\")\n\n\ndef cache_text_encoder_outputs(\n    cache_dir: str, config: SdLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel\n):\n    \"\"\"Run the text encoder on all captions in the dataset and cache the results to disk.\n\n    Args:\n        cache_dir (str): The directory where the results will be cached.\n        config (SdLoraConfig): Training config.\n        tokenizer (CLIPTokenizer): The tokenizer.\n        text_encoder (CLIPTextModel): The text_encoder.\n    \"\"\"\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        shuffle=False,\n        sequential_batching=True,\n    )\n\n    cache = TensorDiskCache(cache_dir)\n\n    for data_batch in tqdm(data_loader):\n        caption_token_ids = 
tokenize_captions(tokenizer, data_batch[\"caption\"]).to(text_encoder.device)\n        text_encoder_output_batch = text_encoder(caption_token_ids)[0]\n        # Split batch before caching.\n        for i in range(len(data_batch[\"id\"])):\n            cache.save(data_batch[\"id\"][i], {\"text_encoder_output\": text_encoder_output_batch[i]})\n\n\ndef cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL):\n    \"\"\"Run the VAE on all images in the dataset and cache the results to disk.\"\"\"\n    cache = TensorDiskCache(cache_dir)\n\n    for data_batch in tqdm(data_loader):\n        latents = vae.encode(data_batch[\"image\"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()\n        latents = latents * vae.config.scaling_factor\n        # Split batch before caching.\n        for i in range(len(data_batch[\"id\"])):\n            data = {\n                \"vae_output\": latents[i],\n                \"original_size_hw\": data_batch[\"original_size_hw\"][i],\n                \"crop_top_left_yx\": data_batch[\"crop_top_left_yx\"][i],\n            }\n            if \"mask\" in data_batch:\n                data[\"mask\"] = data_batch[\"mask\"][i]\n            cache.save(data_batch[\"id\"][i], data)\n\n\ndef train_forward(  # noqa: C901\n    config: SdLoraConfig,\n    data_batch: dict,\n    vae: AutoencoderKL,\n    noise_scheduler: DDPMScheduler,\n    tokenizer: CLIPTokenizer,\n    text_encoder: CLIPTextModel,\n    unet: UNet2DConditionModel,\n    weight_dtype: torch.dtype,\n    use_masks: bool = False,\n    min_snr_gamma: float | None = None,\n) -> torch.Tensor:\n    \"\"\"Run the forward training pass for a single data_batch.\n\n    Returns:\n        torch.Tensor: Loss\n    \"\"\"\n    # Convert images to latent space.\n    # The VAE output may have been cached and included in the data_batch. 
If not, we calculate it here.\n    latents = data_batch.get(\"vae_output\", None)\n    if latents is None:\n        latents = vae.encode(data_batch[\"image\"].to(dtype=weight_dtype)).latent_dist.sample()\n        latents = latents * vae.config.scaling_factor\n\n    # Sample noise that we'll add to the latents.\n    noise = torch.randn_like(latents)\n\n    batch_size = latents.shape[0]\n    # Sample a random timestep for each image.\n    timesteps = torch.randint(\n        0,\n        noise_scheduler.config.num_train_timesteps,\n        (batch_size,),\n        device=latents.device,\n    )\n    timesteps = timesteps.long()\n\n    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward\n    # diffusion process).\n    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n    # Get the text embedding for conditioning.\n    # The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here.\n    encoder_hidden_states = data_batch.get(\"text_encoder_output\", None)\n    if encoder_hidden_states is None:\n        caption_token_ids = tokenize_captions(tokenizer, data_batch[\"caption\"]).to(text_encoder.device)\n        encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)\n\n    # Get the target for loss depending on the prediction type.\n    if config.prediction_type is not None:\n        # Set the prediction_type of scheduler if it's defined in config.\n        noise_scheduler.register_to_config(prediction_type=config.prediction_type)\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        target = noise\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        target = noise_scheduler.get_velocity(latents, noise, timesteps)\n    else:\n        raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n    # Predict the noise residual.\n    model_pred = unet(noisy_latents, 
timesteps, encoder_hidden_states).sample\n\n    min_snr_weights = None\n    if min_snr_gamma is not None:\n        # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.\n        # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n        # This is discussed in Section 4.2 of the same paper.\n\n        snr = compute_snr(noise_scheduler, timesteps)\n\n        # Note: We divide by snr here per Section 4.2 of the paper, since we are predicting the noise instead of x_0.\n        # w_t = min(min_snr_gamma, SNR(t)) / SNR(t)\n        min_snr_weights = torch.clamp(snr, max=min_snr_gamma) / snr\n\n        if noise_scheduler.config.prediction_type == \"epsilon\":\n            pass\n        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n            # Velocity objective needs to be floored to an SNR weight of one.\n            min_snr_weights = min_snr_weights + 1\n        else:\n            raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n\n    if use_masks:\n        # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader.\n        mask = data_batch[\"mask\"].to(dtype=loss.dtype, device=loss.device)\n        _, _, latent_h, latent_w = loss.shape\n        mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode=\"nearest\")\n        loss = loss * mask\n\n    # Mean-reduce the loss along all dimensions except for the batch dimension.\n    loss = loss.mean(dim=list(range(1, len(loss.shape))))\n\n    # Apply min_snr_weights.\n    if min_snr_weights is not None:\n        loss = loss * min_snr_weights\n\n    # Apply per-example loss weights.\n    if \"loss_weight\" in data_batch:\n        loss = loss * data_batch[\"loss_weight\"]\n\n    return loss.mean()\n\n\ndef train(config: SdLoraConfig, callbacks: 
list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Give a clear error message if an unsupported base model was chosen.\n    # TODO(ryan): Update this check to work with single-file SD checkpoints.\n    # check_base_model_version(\n    #     {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},\n    #     config.model,\n    #     local_files_only=False,\n    # )\n\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting LoRA Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        base_embeddings=config.base_embeddings,\n        dtype=weight_dtype,\n    )\n\n    if config.xformers:\n        import_xformers()\n\n        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers\n        # will 
fail because Q, K, V have different dtypes.\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare text encoder output cache.\n    text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there\n        # are a number of configurations that would cause variation in the text encoder outputs and should not be used\n        # with caching.\n        if config.train_text_encoder:\n            raise ValueError(\"'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.\")\n\n        # We use a temporary directory for the cache. The directory will automatically be cleaned up when\n        # tmp_text_encoder_output_cache_dir is destroyed.\n        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()\n        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').\")\n            text_encoder.to(accelerator.device, dtype=weight_dtype)\n            cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder)\n        # Move the text_encoder back to the CPU, because it is not needed for training.\n        text_encoder.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not config.data_loader.center_crop:\n            raise 
ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            data_loader = _build_data_loader(\n                data_loader_config=config.data_loader,\n                batch_size=config.train_batch_size,\n                use_masks=config.use_masks,\n                shuffle=False,\n                sequential_batching=True,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # Add LoRA layers to the models being trained.\n    trainable_param_groups = []\n    all_trainable_models: list[peft.PeftModel] = []\n\n    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:\n        peft_model = peft.get_peft_model(model, lora_config)\n        peft_model.print_trainable_parameters()\n\n        # Populate `trainable_param_groups`, to be passed to the optimizer.\n        param_group = {\"params\": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}\n        if lr is not None:\n            param_group[\"lr\"] = lr\n        trainable_param_groups.append(param_group)\n\n        # Populate all_trainable_models.\n        
all_trainable_models.append(peft_model)\n\n        peft_model.train()\n\n        return peft_model\n\n    # Add LoRA layers to the model.\n    if config.train_unet:\n        unet_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?\n            lora_alpha=1.0,\n            target_modules=config.unet_lora_target_modules,\n        )\n        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)\n\n    if config.train_text_encoder:\n        text_encoder_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            lora_alpha=1.0,\n            # init_lora_weights=\"gaussian\",\n            target_modules=config.text_encoder_lora_target_modules,\n        )\n        text_encoder = inject_lora_layers(text_encoder, text_encoder_lora_config, lr=config.text_encoder_learning_rate)\n\n    # If mixed_precision is enabled, cast all trainable params to float32.\n    if config.mixed_precision != \"no\":\n        for trainable_model in all_trainable_models:\n            for param in trainable_model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n        if config.train_text_encoder:\n            text_encoder.gradient_checkpointing_enable()\n\n            # The text encoder must be in train() mode for gradient checkpointing to take effect. 
This should\n            # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear\n            # that this is required.\n            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text\n            # encoders in train mode does not change their forward behavior.\n            text_encoder.train()\n\n            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder\n            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of\n            # trainable_param_groups has already been populated - the embeddings will not be trained.\n            text_encoder.text_model.embeddings.requires_grad_(True)\n\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,\n        vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) 
is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        UNet2DConditionModel,\n        CLIPTextModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        unet,\n        text_encoder,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[True, not config.cache_text_encoder_outputs, True, True, True],\n    )\n    unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        
accelerator.init_trackers(\"lora_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        max_checkpoints=config.max_checkpoints,\n        extension=\".safetensors\" if config.lora_checkpoint_format == \"kohya\" else None,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_sd_lora_checkpoint(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                unet=accelerator.unwrap_model(unet) if config.train_unet else None,\n                text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                lora_checkpoint_format=config.lora_checkpoint_format,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sd(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, 
num_train_epochs):\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(unet, text_encoder):\n                loss = train_forward(\n                    config=config,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer=tokenizer,\n                    text_encoder=text_encoder,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    use_masks=config.use_masks,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n                if config.train_unet:\n    
                # When training the UNet, it will always be the first parameter group.\n                    log[\"lr/unet\"] = float(lrs[0])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/unet\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n                if config.train_text_encoder:\n                    # When training the text encoder, it will always be the last parameter group.\n                    log[\"lr/text_encoder\"] = float(lrs[-1])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/text_encoder\"] = optimizer.param_groups[-1][\"d\"] * optimizer.param_groups[-1][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, 
num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py",
    "content": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdTextualInversionConfig(BasePipelineConfig):\n    type: Literal[\"SD_TEXTUAL_INVERSION\"] = \"SD_TEXTUAL_INVERSION\"\n    \"\"\"Must be `SD_TEXTUAL_INVERSION`. This is what differentiates training pipeline types.\n    \"\"\"\n\n    model: str\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. `\"runwayml/stable-diffusion-v1-5\"`, `\"stabilityai/stable-diffusion-xl-base-1.0\"`,\n    `\"/path/to/local/model.safetensors\"`, etc.)\n\n    The model architecture must match the training pipeline being run. For example, if running a\n    Textual Inversion SDXL pipeline, then `model` must refer to an SDXL model.\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. 
Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    # Helpful discussion for understanding how this works at inference time:\n    # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509\n    num_vectors: int = 1\n    \"\"\"Note: `num_vectors` can be overridden by `initial_phrase`.\n\n    The number of textual inversion embedding vectors that will be used to learn the concept.\n\n    Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:\n\n    - greater risk of overfitting\n    - increased size of the resulting output file\n    - consumes more of the prompt capacity at inference time\n\n    Typical values for `num_vectors` are in the range [1, 16].\n\n    As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).\n    \"\"\"\n\n    placeholder_token: str\n    \"\"\"The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already\n    exist in the tokenizer's vocabulary.\n    \"\"\"\n\n    initializer_token: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly\n    describes the object or style that you're trying to train on. Must map to a single tokenizer token.\n\n    For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.\n    \"\"\"\n\n    initial_embedding_file: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    Path to an existing TI embedding that will be used to initialize the embedding being trained. 
The placeholder\n    token in the file must match the `placeholder_token` field.\n\n    Either `initializer_token` or `initial_embedding_file` should be set.\n    \"\"\"\n\n    initial_phrase: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the\n    corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be\n    inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large\n    number of embedding vectors are discussed in the `num_vectors` field documentation.\n\n    For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. 
If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step).\n\n    This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`,\n    `random_flip=False`, etc.).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the\n    `train_batch_size` when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If `True`, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction type that will be used for training. 
If `None`, the prediction type will be inferred from the\n    scheduler.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Maximum gradient norm for gradient clipping. Set to `null` or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can\n    become very slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: TextualInversionSDDataLoaderConfig\n    \"\"\"The data configuration.\n\n    See\n    [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]\n    for details.\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py",
    "content": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom diffusers.optimization import get_scheduler\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\nfrom invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (\n    build_textual_inversion_sd_dataloader,\n)\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd\nfrom invoke_training._shared.stable_diffusion.textual_inversion import (\n    initialize_placeholder_tokens_from_initial_embedding,\n    initialize_placeholder_tokens_from_initial_phrase,\n    initialize_placeholder_tokens_from_initializer_token,\n    restore_original_embeddings,\n)\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs, train_forward\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig\n\n\ndef _save_ti_embeddings(\n    epoch: int,\n    step: int,\n    text_encoder: CLIPTextModel,\n    placeholder_token_ids: list[int],\n    
accelerator: Accelerator,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    callbacks: list[PipelineCallbacks] | None,\n):\n    \"\"\"Save a Textual Inversion checkpoint. Old checkpoints are deleted if necessary to respect the checkpoint_tracker\n    limits.\n    \"\"\"\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    learned_embeds = (\n        accelerator.unwrap_model(text_encoder)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]\n    )\n    learned_embeds_dict = {\"emb_params\": learned_embeds.detach().cpu().to(torch.float32)}\n\n    save_state_dict(learned_embeds_dict, save_path)\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SD1_TEXTUAL_INVERSION)],\n                    epoch=epoch,\n                    step=step,\n                )\n            )\n\n\ndef _initialize_placeholder_tokens(\n    config: SdTextualInversionConfig,\n    tokenizer: CLIPTokenizer,\n    text_encoder: PreTrainedTokenizer,\n    logger: logging.Logger,\n) -> tuple[list[str], list[int]]:\n    \"\"\"Prepare the tokenizer and text_encoder for TI training.\n\n    - Add the placeholder tokens to the tokenizer.\n    - Add new token embeddings to the text_encoder for each of the placeholder tokens.\n    - Initialize the new token embeddings from either an existing token, or an initial TI embedding file.\n    \"\"\"\n    if (\n        sum(\n            [\n                config.initializer_token is not None,\n                config.initial_embedding_file is not None,\n                config.initial_phrase is not None,\n        
    ]\n        )\n        != 1\n    ):\n        raise ValueError(\n            \"Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set.\"\n        )\n\n    if config.initializer_token is not None:\n        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initializer_token(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            initializer_token=config.initializer_token,\n            placeholder_token=config.placeholder_token,\n            num_vectors=config.num_vectors,\n            logger=logger,\n        )\n    elif config.initial_embedding_file is not None:\n        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_embedding(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            initial_embedding_file=config.initial_embedding_file,\n            placeholder_token=config.placeholder_token,\n            num_vectors=config.num_vectors,\n        )\n    elif config.initial_phrase is not None:\n        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_phrase(\n            tokenizer=tokenizer,\n            text_encoder=text_encoder,\n            initial_phrase=config.initial_phrase,\n            placeholder_token=config.placeholder_token,\n        )\n    else:\n        raise ValueError(\n            \"Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set.\"\n        )\n\n    return placeholder_tokens, placeholder_token_ids\n\n\ndef train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, 
config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting Textual Inversion Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype\n    )\n\n    placeholder_tokens, placeholder_token_ids = _initialize_placeholder_tokens(\n        config=config, tokenizer=tokenizer, text_encoder=text_encoder, logger=logger\n    )\n    logger.info(f\"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.\")\n\n    # All parameters of the VAE, UNet, and text encoder are currently frozen. 
Just unfreeze the token embeddings in the\n    # text encoder.\n    text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n\n        # The text_encoder will be put in .train() mode later, so we don't need to worry about that here.\n        # Note: There are some weird interactions between gradient checkpointing and requires_grad_() when training a\n        # text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how this is handled in\n        # other training pipelines.\n        text_encoder.gradient_checkpointing_enable()\n\n    if config.xformers:\n        import_xformers()\n\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not config.data_loader.center_crop:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            data_loader = build_textual_inversion_sd_dataloader(\n                config=config.data_loader,\n                placeholder_token=config.placeholder_token,\n                batch_size=config.train_batch_size,\n                use_masks=config.use_masks,\n                shuffle=False,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder.to(accelerator.device, dtype=weight_dtype)\n\n    # Initialize the optimizer to only optimize the token embeddings.\n    optimizer = initialize_optimizer(config.optimizer, text_encoder.get_input_embeddings().parameters())\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config.data_loader,\n        placeholder_token=config.placeholder_token,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert 
sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    # Prepare everything with our `accelerator`.\n    text_encoder, optimizer, data_loader, lr_scheduler = accelerator.prepare(\n        text_encoder, optimizer, data_loader, lr_scheduler\n    )\n\n    prepared_result: tuple[\n        CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler\n    ] = accelerator.prepare(text_encoder, optimizer, data_loader, lr_scheduler)\n    text_encoder, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        extension=\".safetensors\",\n        max_checkpoints=config.max_checkpoints,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    
logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    # Keep original embeddings as reference.\n    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_ti_embeddings(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                text_encoder=text_encoder,\n                placeholder_token_ids=placeholder_token_ids,\n                accelerator=accelerator,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sd(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder=text_encoder,\n                tokenizer=tokenizer,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                
logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        text_encoder.train()\n\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(text_encoder):\n                loss = train_forward(\n                    config=config,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer=tokenizer,\n                    text_encoder=text_encoder,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    use_masks=config.use_masks,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    # TODO(ryand): I copied this from another pipeline. 
Should probably just clip the trainable params.\n                    params_to_clip = text_encoder.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n                # Make sure we don't update any embedding weights besides the newly-added token(s).\n                # TODO(ryand): Should we only do this if accelerator.sync_gradients?\n                restore_original_embeddings(\n                    tokenizer=tokenizer,\n                    placeholder_token_ids=placeholder_token_ids,\n                    accelerator=accelerator,\n                    text_encoder=text_encoder,\n                    orig_embeds_params=orig_embeds_params,\n                )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss, \"lr\": lr_scheduler.get_last_lr()[0]}\n\n                if config.optimizer.optimizer_type == \"Prodigy\":\n                    # TODO(ryand): Test Prodigy logging.\n                    log[\"lr/d*lr\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and 
global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\"step_loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdxlFinetuneConfig(BasePipelineConfig):\n    type: Literal[\"SDXL_FINETUNE\"] = \"SDXL_FINETUNE\"\n\n    model: str = \"stabilityai/stable-diffusion-xl-base-1.0\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    save_checkpoint_format: Literal[\"full_diffusers\", \"trained_only_diffusers\"] = \"trained_only_diffusers\"\n    \"\"\"The save format for the checkpoints.\n\n    Options:\n\n    - `full_diffusers`: Save the full model in diffusers format (including models that weren't finetuned). If you want a\n    single output artifact that can be used for generation, then this is the recommended option.\n    - `trained_only_diffusers`: Save only the models that were finetuned in diffusers format. For example, if only the\n    UNet model was trained, then only the UNet model will be saved. This option will significantly reduce the disk space\n    consumed by the saved checkpoints. 
If you plan to extract a LoRA from the fine-tuned model, then this is the\n    recommended option.\n    \"\"\"\n\n    save_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"float16\"\n    \"\"\"The dtype to use when saving the model.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. 
This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If true, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    See also 'validate_every_n_epochs'.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: Annotated[\n        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator=\"type\")\n    ]\n\n    vae_model: str | None = None\n    \"\"\"The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base\n    model (specified by the `model` parameter). 
This config option is provided for SDXL models, because SDXL shipped\n    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py",
    "content": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom typing import Literal\n\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.checkpoint_utils import (\n    save_sdxl_diffusers_checkpoint,\n    save_sdxl_diffusers_unet_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.train import (\n    _build_data_loader,\n    cache_text_encoder_outputs,\n    train_forward,\n)\n\n\ndef _save_sdxl_checkpoint(\n    epoch: int,\n    step: int,\n    save_checkpoint_format: Literal[\"full_diffusers\", \"trained_only_diffusers\"],\n    vae: AutoencoderKL,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    
tokenizer_1: CLIPTokenizer,\n    tokenizer_2: CLIPTokenizer,\n    noise_scheduler: DDPMScheduler,\n    unet: UNet2DConditionModel,\n    save_dtype: torch.dtype,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    callbacks: list[PipelineCallbacks] | None,\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    if save_checkpoint_format == \"trained_only_diffusers\":\n        model_type = ModelType.SDXL_UNET_DIFFUSERS\n        save_sdxl_diffusers_unet_checkpoint(checkpoint_path=save_path, unet=unet, save_dtype=save_dtype)\n    elif save_checkpoint_format == \"full_diffusers\":\n        model_type = ModelType.SDXL_FULL_DIFFUSERS\n        save_sdxl_diffusers_checkpoint(\n            checkpoint_path=save_path,\n            vae=vae,\n            text_encoder_1=text_encoder_1,\n            text_encoder_2=text_encoder_2,\n            tokenizer_1=tokenizer_1,\n            tokenizer_2=tokenizer_2,\n            noise_scheduler=noise_scheduler,\n            unet=unet,\n            save_dtype=save_dtype,\n        )\n    else:\n        raise ValueError(f\"Invalid save_checkpoint_format: '{save_checkpoint_format}'.\")\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=model_type)],\n                    epoch=epoch,\n                    step=step,\n                )\n            )\n\n\ndef train(config: SdxlFinetuneConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Give a clear error message if an unsupported base model was chosen.\n    # TODO(ryan): Update this check to work with single-file SD checkpoints.\n    # check_base_model_version(\n    #     
{BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},\n    #     config.model,\n    #     local_files_only=False,\n    # )\n\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(\n        logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        vae_model=config.vae_model,\n        base_embeddings=None,\n        dtype=weight_dtype,\n    )\n\n    if config.xformers:\n        import_xformers()\n\n        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers\n        # will fail because Q, K, V have different dtypes.\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare text encoder output cache.\n    
text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there\n        # are a number of configurations that would cause variation in the text encoder outputs and should not be used\n        # with caching.\n\n        # We use a temporary directory for the cache. The directory will automatically be cleaned up when\n        # tmp_text_encoder_output_cache_dir is destroyed.\n        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()\n        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').\")\n            text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n            text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n            # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another\n            # pipeline.\n            cache_text_encoder_outputs(\n                text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2\n            )\n        # Move the text_encoders back to the CPU, because they are not needed for training.\n        text_encoder_1.to(\"cpu\")\n        text_encoder_2.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not 
config.data_loader.center_crop:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another\n            # pipeline.\n            data_loader = _build_data_loader(\n                data_loader_config=config.data_loader,\n                batch_size=config.train_batch_size,\n                use_masks=config.use_masks,\n                shuffle=False,\n                sequential_batching=True,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # Make UNet trainable.\n    unet.requires_grad_(True)\n    unet.train()\n    all_trainable_models = [unet]\n\n    # If mixed_precision is enabled, cast all trainable params to float32.\n    if config.mixed_precision != \"no\":\n        for trainable_model in all_trainable_models:\n            for param in trainable_model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient 
checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n\n    optimizer = initialize_optimizer(config.optimizer, unet.parameters())\n\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,\n        vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. 
This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        UNet2DConditionModel,\n        peft.PeftModel | CLIPTextModel,\n        peft.PeftModel | CLIPTextModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        unet,\n        text_encoder_1,\n        text_encoder_2,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[\n            True,\n            not config.cache_text_encoder_outputs,\n            not config.cache_text_encoder_outputs,\n            True,\n            True,\n            True,\n        ],\n    )\n    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"finetune\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir, prefix=\"checkpoint\", max_checkpoints=config.max_checkpoints\n    )\n\n    # 
Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_sdxl_checkpoint(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                save_checkpoint_format=config.save_checkpoint_format,\n                vae=vae,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                save_dtype=get_dtype_from_str(config.save_dtype),\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: 
int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sdxl(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):\n                loss = train_forward(\n                    accelerator=accelerator,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer_1=tokenizer_1,\n                    tokenizer_2=tokenizer_2,\n                    text_encoder_1=text_encoder_1,\n                    text_encoder_2=text_encoder_2,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    resolution=config.data_loader.resolution,\n                    use_masks=config.use_masks,\n                    prediction_type=config.prediction_type,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n   
             train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n                # When training the UNet, it will always be the first parameter group.\n                log[\"lr/unet\"] = float(lrs[0])\n                if config.optimizer.optimizer_type == \"Prodigy\":\n                    log[\"lr/d*lr/unet\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    
validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py",
    "content": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    TEXT_ENCODER_TARGET_MODULES,\n    UNET_TARGET_MODULES,\n)\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdxlLoraConfig(BasePipelineConfig):\n    type: Literal[\"SDXL_LORA\"] = \"SDXL_LORA\"\n\n    model: str = \"stabilityai/stable-diffusion-xl-base-1.0\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    # Note: Pydantic handles mutable default values well:\n    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values\n    base_embeddings: dict[str, str] = {}\n    \"\"\"A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model\n    before training.\n\n    Example:\n    ```\n    base_embeddings = {\n        \"bruce_the_gnome\": \"/path/to/bruce_the_gnome.safetensors\",\n    }\n    ```\n\n    Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the\n    dataset captions.\n\n    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they\n    are referenced in the dataset captions. 
The list of embeddings provided here should be the same list used at\n    generation time with the resultant LoRA model.\n    \"\"\"\n\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"kohya\"\n    \"\"\"The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`.\"\"\"\n\n    train_unet: bool = True\n    \"\"\"Whether to add LoRA layers to the UNet model and train it.\n    \"\"\"\n\n    train_text_encoder: bool = True\n    \"\"\"Whether to add LoRA layers to the text encoder and train it.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    text_encoder_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning\n    rate. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    unet_learning_rate: float | None = None\n    \"\"\"The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.\n    Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. 
If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    lora_rank_dim: int = 4\n    \"\"\"The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,\n    but also increases the size of the generated LoRA model.\n    \"\"\"\n\n    # The default list of target modules is based on\n    # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87\n    unet_lora_target_modules: list[str] = UNET_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly\n    expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    unet_lora_target_modules = [\"to_k\", \"to_q\", \"to_v\", \"to_out.0\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES\n    \"\"\"The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a\n    highly expressive LoRA model.\n\n    For a smaller and less expressive LoRA model, the following list is recommended:\n    ```python\n    text_encoder_lora_target_modules = [\"fc1\", \"fc2\", \"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n    ```\n\n    The list of target modules is passed to Hugging Face's PEFT library. 
See\n    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for\n    details.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If true, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    See also 'validate_every_n_epochs'.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: Annotated[\n        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator=\"type\")\n    ]\n\n    vae_model: str | None = None\n    \"\"\"The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base\n    model (specified by the `model` parameter). 
This config option is provided for SDXL models, because SDXL shipped\n    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py",
    "content": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nfrom typing import Literal, Optional, Union\n\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom torch.utils.data import DataLoader\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPPreTrainedModel, CLIPTextModel, PreTrainedTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\nfrom invoke_training._shared.data.utils.resolution import Resolution\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    save_sdxl_kohya_checkpoint,\n    save_sdxl_peft_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl\nfrom invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom 
invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig\n\n\ndef _save_sdxl_lora_checkpoint(\n    epoch: int,\n    step: int,\n    unet: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"],\n    callbacks: list[PipelineCallbacks] | None,\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    if lora_checkpoint_format == \"invoke_peft\":\n        model_type = ModelType.SD1_LORA_PEFT\n        save_sdxl_peft_checkpoint(\n            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2\n        )\n    elif lora_checkpoint_format == \"kohya\":\n        model_type = ModelType.SD1_LORA_KOHYA\n        save_sdxl_kohya_checkpoint(\n            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2\n        )\n    else:\n        raise ValueError(f\"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.\")\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step\n                )\n            )\n\n\ndef _build_data_loader(\n    data_loader_config: 
Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig],\n    batch_size: int,\n    use_masks: bool = False,\n    text_encoder_output_cache_dir: Optional[str] = None,\n    vae_output_cache_dir: Optional[str] = None,\n    shuffle: bool = True,\n    sequential_batching: bool = False,\n) -> DataLoader:\n    if data_loader_config.type == \"IMAGE_CAPTION_SD_DATA_LOADER\":\n        return build_image_caption_sd_dataloader(\n            config=data_loader_config,\n            batch_size=batch_size,\n            use_masks=use_masks,\n            text_encoder_output_cache_dir=text_encoder_output_cache_dir,\n            text_encoder_cache_field_to_output_field={\n                \"prompt_embeds\": \"prompt_embeds\",\n                \"pooled_prompt_embeds\": \"pooled_prompt_embeds\",\n            },\n            vae_output_cache_dir=vae_output_cache_dir,\n            shuffle=shuffle,\n        )\n    elif data_loader_config.type == \"DREAMBOOTH_SD_DATA_LOADER\":\n        if use_masks:\n            raise ValueError(\"Masks are not yet supported for DreamBooth data loaders.\")\n        return build_dreambooth_sd_dataloader(\n            config=data_loader_config,\n            batch_size=batch_size,\n            text_encoder_output_cache_dir=text_encoder_output_cache_dir,\n            text_encoder_cache_field_to_output_field={\n                \"prompt_embeds\": \"prompt_embeds\",\n                \"pooled_prompt_embeds\": \"pooled_prompt_embeds\",\n            },\n            vae_output_cache_dir=vae_output_cache_dir,\n            shuffle=shuffle,\n            sequential_batching=sequential_batching,\n        )\n    else:\n        raise ValueError(f\"Unsupported data loader config type: '{data_loader_config.type}'.\")\n\n\n# encode_prompt was adapted from:\n# https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L470-L496\ndef _encode_prompt(text_encoders: 
list[CLIPPreTrainedModel], prompt_token_ids_list: list[torch.Tensor]):\n    prompt_embeds_list = []\n\n    for i, text_encoder in enumerate(text_encoders):\n        text_input_ids = prompt_token_ids_list[i]\n\n        prompt_embeds = text_encoder(\n            text_input_ids.to(text_encoder.device),\n            output_hidden_states=True,\n        )\n\n        # We are only ALWAYS interested in the pooled output of the final text encoder.\n        # TODO(ryand): Document this logic more clearly.\n        pooled_prompt_embeds = prompt_embeds[0]\n        prompt_embeds = prompt_embeds.hidden_states[-2]\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)\n        prompt_embeds_list.append(prompt_embeds)\n\n    prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)\n    pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)\n    return prompt_embeds, pooled_prompt_embeds\n\n\n# TODO(ryand): Cache VAE outputs and text encoder outputs at the same time in a single pass over the dataset.\n\n\ndef cache_text_encoder_outputs(\n    cache_dir: str,\n    config: SdxlLoraConfig,\n    tokenizer_1: PreTrainedTokenizer,\n    tokenizer_2: PreTrainedTokenizer,\n    text_encoder_1: CLIPPreTrainedModel,\n    text_encoder_2: CLIPPreTrainedModel,\n):\n    \"\"\"Run the text encoder on all captions in the dataset and cache the results to disk.\n    Args:\n        cache_dir (str): The directory where the results will be cached.\n        config (FinetuneLoRAConfig): Training config.\n        tokenizer_1 (PreTrainedTokenizer):\n        tokenizer_2 (PreTrainedTokenizer):\n        text_encoder_1 (CLIPPreTrainedModel):\n        text_encoder_2 (CLIPPreTrainedModel):\n    \"\"\"\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        shuffle=False,\n        sequential_batching=True,\n    )\n\n    cache = TensorDiskCache(cache_dir)\n\n    
for data_batch in tqdm(data_loader):\n        caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch[\"caption\"])\n        caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch[\"caption\"])\n        prompt_embeds, pooled_prompt_embeds = _encode_prompt(\n            [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2]\n        )\n\n        # Split batch before caching.\n        for i in range(len(data_batch[\"id\"])):\n            embeds = {\n                \"prompt_embeds\": prompt_embeds[i],\n                \"pooled_prompt_embeds\": pooled_prompt_embeds[i],\n            }\n            cache.save(data_batch[\"id\"][i], embeds)\n\n\ndef train_forward(  # noqa: C901\n    accelerator: Accelerator,\n    data_batch: dict,\n    vae: AutoencoderKL,\n    noise_scheduler: DDPMScheduler,\n    tokenizer_1: PreTrainedTokenizer,\n    tokenizer_2: PreTrainedTokenizer,\n    text_encoder_1: CLIPPreTrainedModel,\n    text_encoder_2: CLIPPreTrainedModel,\n    unet: UNet2DConditionModel,\n    weight_dtype: torch.dtype,\n    resolution: int | tuple[int, int],\n    use_masks: bool = False,\n    prediction_type=None,\n    min_snr_gamma: float | None = None,\n):\n    \"\"\"Run the forward training pass for a single data_batch.\n\n    Returns:\n        torch.Tensor: Loss\n    \"\"\"\n    # Convert images to latent space.\n    # The VAE output may have been cached and included in the data_batch. 
If not, we calculate it here.\n    latents = data_batch.get(\"vae_output\", None)\n    if latents is None:\n        latents = vae.encode(data_batch[\"image\"].to(dtype=weight_dtype)).latent_dist.sample()\n        latents = latents * vae.config.scaling_factor\n\n    # Sample noise that we'll add to the latents.\n    noise = torch.randn_like(latents)\n\n    batch_size = latents.shape[0]\n    # Sample a random timestep for each image.\n    timesteps = torch.randint(\n        0,\n        noise_scheduler.config.num_train_timesteps,\n        (batch_size,),\n        device=latents.device,\n    )\n    timesteps = timesteps.long()\n\n    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion\n    # process).\n    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n    # compute_time_ids was copied from:\n    # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L1033-L1039\n    # \"time_ids\" may seem like a weird naming choice. The name comes from the diffusers SDXL implementation. 
Presumably,\n    # it is a result of the fact that the original size and crop values get concatenated with the time embeddings.\n    def compute_time_ids(original_size, crops_coords_top_left):\n        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids\n        target_size = Resolution.parse(resolution).to_tuple()\n        add_time_ids = list(original_size + crops_coords_top_left + target_size)\n        add_time_ids = torch.tensor([add_time_ids])\n        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)\n        return add_time_ids\n\n    add_time_ids = torch.cat(\n        [compute_time_ids(s, c) for s, c in zip(data_batch[\"original_size_hw\"], data_batch[\"crop_top_left_yx\"])]\n    )\n    unet_conditions = {\"time_ids\": add_time_ids}\n\n    # Get the text embedding for conditioning.\n    # The text encoder output may have been cached and included in the data_batch. If not, we calculate it here.\n    if \"prompt_embeds\" in data_batch:\n        prompt_embeds = data_batch[\"prompt_embeds\"]\n        pooled_prompt_embeds = data_batch[\"pooled_prompt_embeds\"]\n    else:\n        caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch[\"caption\"])\n        caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch[\"caption\"])\n        prompt_embeds, pooled_prompt_embeds = _encode_prompt(\n            [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2]\n        )\n        prompt_embeds = prompt_embeds.to(dtype=weight_dtype)\n        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)\n\n    unet_conditions[\"text_embeds\"] = pooled_prompt_embeds\n\n    # Get the target for loss depending on the prediction type.\n    if prediction_type is not None:\n        # Set the prediction_type of scheduler if it's defined in config.\n        noise_scheduler.register_to_config(prediction_type=prediction_type)\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        
target = noise\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        target = noise_scheduler.get_velocity(latents, noise, timesteps)\n    else:\n        raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n    # Predict the noise residual.\n    model_pred = unet(noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_conditions).sample\n\n    min_snr_weights = None\n    if min_snr_gamma is not None:\n        # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.\n        # Since we predict the noise instead of x_0, the original formulation is slightly changed.\n        # This is discussed in Section 4.2 of the same paper.\n\n        snr = compute_snr(noise_scheduler, timesteps)\n\n        # Note: We divide by snr here per Section 4.2 of the paper, since we are predicting the noise instead of x_0.\n        # w_t = min(gamma, SNR(t)) / SNR(t)\n        min_snr_weights = torch.clamp(snr, max=min_snr_gamma) / snr\n\n        if noise_scheduler.config.prediction_type == \"epsilon\":\n            pass\n        elif noise_scheduler.config.prediction_type == \"v_prediction\":\n            # Velocity objective needs to be floored to an SNR weight of one.\n            min_snr_weights = min_snr_weights + 1\n        else:\n            raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n\n    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction=\"none\")\n\n    if use_masks:\n        # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader.\n        mask = data_batch[\"mask\"].to(dtype=loss.dtype, device=loss.device)\n        _, _, latent_h, latent_w = loss.shape\n        mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode=\"nearest\")\n        loss = loss * mask\n\n    # Mean-reduce the loss along all dimensions except for the batch dimension.\n    loss 
= loss.mean(dim=list(range(1, len(loss.shape))))\n\n    # Apply min_snr_weights.\n    if min_snr_weights is not None:\n        loss = loss * min_snr_weights\n\n    # Apply per-example loss weights.\n    if \"loss_weight\" in data_batch:\n        loss = loss * data_batch[\"loss_weight\"]\n\n    return loss.mean()\n\n\ndef train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Give a clear error message if an unsupported base model was chosen.\n    # TODO(ryan): Update this check to work with single-file SD checkpoints.\n    # check_base_model_version(\n    #     {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},\n    #     config.model,\n    #     local_files_only=False,\n    # )\n\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(\n        
logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        vae_model=config.vae_model,\n        base_embeddings=config.base_embeddings,\n        dtype=weight_dtype,\n    )\n\n    if config.xformers:\n        import_xformers()\n\n        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers\n        # will fail because Q, K, V have different dtypes.\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare text encoder output cache.\n    text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there\n        # are a number of configurations that would cause variation in the text encoder outputs and should not be used\n        # with caching.\n\n        if config.train_text_encoder:\n            raise ValueError(\"'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.\")\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_text_encoder_output_cache_dir is destroyed.\n        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()\n        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should populate the cache.\n            logger.info(f\"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').\")\n            text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n            text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n            cache_text_encoder_outputs(\n                text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2\n            )\n        # Move the text_encoders back to the CPU, because they are not needed for training.\n        text_encoder_1.to(\"cpu\")\n        text_encoder_2.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not config.data_loader.center_crop:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should to populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            data_loader = _build_data_loader(\n                data_loader_config=config.data_loader,\n                batch_size=config.train_batch_size,\n                use_masks=config.use_masks,\n                shuffle=False,\n                sequential_batching=True,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # Add LoRA layers to the models being trained.\n    trainable_param_groups = []\n    all_trainable_models: list[peft.PeftModel] = []\n\n    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:\n        peft_model = peft.get_peft_model(model, lora_config)\n        peft_model.print_trainable_parameters()\n\n        # Populate `trainable_param_groups`, to be passed to the optimizer.\n        param_group = {\"params\": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}\n        if lr is not None:\n            param_group[\"lr\"] = lr\n        trainable_param_groups.append(param_group)\n\n        # Populate all_trainable_models.\n        all_trainable_models.append(peft_model)\n\n        peft_model.train()\n\n        return peft_model\n\n    if config.train_unet:\n        
unet_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?\n            lora_alpha=1.0,\n            target_modules=config.unet_lora_target_modules,\n        )\n        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)\n\n    if config.train_text_encoder:\n        text_encoder_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            lora_alpha=1.0,\n            # init_lora_weights=\"gaussian\",\n            target_modules=config.text_encoder_lora_target_modules,\n        )\n        text_encoder_1 = inject_lora_layers(\n            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate\n        )\n        text_encoder_2 = inject_lora_layers(\n            text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate\n        )\n\n    # If mixed_precision is enabled, cast all trainable params to float32.\n    if config.mixed_precision != \"no\":\n        for trainable_model in all_trainable_models:\n            for param in trainable_model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n        if config.train_text_encoder:\n            for te in [text_encoder_1, text_encoder_2]:\n                te.gradient_checkpointing_enable()\n\n                # The text encoders must be in train() mode for gradient checkpointing to take effect. 
This should\n                # already be the case, since we are training the text_encoders, but we do it explicitly to make it clear\n                # that this is required.\n                # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text\n                # encoders in train mode does not change their forward behavior.\n                te.train()\n\n                # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder\n                # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of\n                # trainable_param_groups has already been populated - the embeddings will not be trained.\n                te.text_model.embeddings.requires_grad_(True)\n\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n\n    data_loader = _build_data_loader(\n        data_loader_config=config.data_loader,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,\n        vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) 
is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        UNet2DConditionModel,\n        peft.PeftModel | CLIPTextModel,\n        peft.PeftModel | CLIPTextModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        unet,\n        text_encoder_1,\n        text_encoder_2,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[\n            True,\n            not config.cache_text_encoder_outputs,\n            not config.cache_text_encoder_outputs,\n          
  True,\n            True,\n            True,\n        ],\n    )\n    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"lora_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        max_checkpoints=config.max_checkpoints,\n        extension=\".safetensors\" if config.lora_checkpoint_format == \"kohya\" else None,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_sdxl_lora_checkpoint(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                unet=unet if config.train_unet else None,\n                text_encoder_1=text_encoder_1 if config.train_text_encoder else None,\n                text_encoder_2=text_encoder_2 if config.train_text_encoder else None,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                lora_checkpoint_format=config.lora_checkpoint_format,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sdxl(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                logger=logger,\n    
            callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):\n                loss = train_forward(\n                    accelerator=accelerator,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer_1=tokenizer_1,\n                    tokenizer_2=tokenizer_2,\n                    text_encoder_1=text_encoder_1,\n                    text_encoder_2=text_encoder_2,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    resolution=config.data_loader.resolution,\n                    use_masks=config.use_masks,\n                    prediction_type=config.prediction_type,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n            # Checks if the accelerator has performed an optimization 
step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n                if config.train_unet:\n                    # When training the UNet, it will always be the first parameter group.\n                    log[\"lr/unet\"] = float(lrs[0])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/unet\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n                if config.train_text_encoder:\n                    # When training the text encoder, it will always be the last parameter group.\n                    log[\"lr/text_encoder\"] = float(lrs[-1])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[\"lr/d*lr/text_encoder\"] = optimizer.param_groups[-1][\"d\"] * optimizer.param_groups[-1][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": 
lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py",
    "content": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdxlLoraAndTextualInversionConfig(BasePipelineConfig):\n    type: Literal[\"SDXL_LORA_AND_TEXTUAL_INVERSION\"] = \"SDXL_LORA_AND_TEXTUAL_INVERSION\"\n\n    model: str = \"stabilityai/stable-diffusion-xl-base-1.0\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"] = \"kohya\"\n    \"\"\"The format of the LoRA checkpoint to save. 
Choose between `invoke_peft` or `kohya`.\"\"\"\n\n    # Helpful discussion for understanding how this works at inference time:\n    # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509\n    num_vectors: int = 1\n    \"\"\"Note: `num_vectors` can be overridden by `initial_phrase`.\n\n    The number of textual inversion embedding vectors that will be used to learn the concept.\n\n    Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:\n\n    - greater risk of overfitting\n    - increased size of the resulting output file\n    - consumes more of the prompt capacity at inference time\n\n    Typical values for `num_vectors` are in the range [1, 16].\n\n    As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).\n    \"\"\"\n\n    placeholder_token: str\n    \"\"\"The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already\n    exist in the tokenizer's vocabulary.\n    \"\"\"\n\n    initializer_token: str | None = None\n    \"\"\"A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly\n    describes the object or style that you're trying to train on. Must map to a single tokenizer token.\n\n    For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.\n    \"\"\"\n\n    initial_phrase: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token` or `initial_phrase` should be set.\n\n    A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the\n    corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be\n    inferred from the length of the tokenized phrase, so keep the phrase short. 
The consequences of training a large\n    number of embedding vectors are discussed in the `num_vectors` field documentation.\n\n    For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.\n    \"\"\"\n\n    train_unet: bool = True\n    \"\"\"Whether to add LoRA layers to the UNet model and train it.\n    \"\"\"\n\n    train_text_encoder: bool = True\n    \"\"\"Whether to add LoRA layers to the text encoder and train it.\n    \"\"\"\n\n    train_ti: bool = True\n    \"\"\"Whether to train the textual inversion embeddings.\"\"\"\n\n    ti_train_steps_ratio: float | None = None\n    \"\"\"The fraction of the total training steps for which the TI embeddings will be trained. For example, if we are\n    training for a total of 5000 steps and `ti_train_steps_ratio=0.5`, then the TI embeddings will be trained for 2500\n    steps and then will be frozen for the remaining steps.\n\n    If `None`, then the TI embeddings will be trained for the entire duration of training.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    text_encoder_learning_rate: float | None = 1e-5\n    \"\"\"The learning rate to use for the text encoder model. Set to null or 0 to use the optimizer's default learning\n    rate.\n    \"\"\"\n\n    unet_learning_rate: float | None = 1e-4\n    \"\"\"The learning rate to use for the UNet model. Set to null or 0 to use the optimizer's default learning rate.\n    \"\"\"\n\n    textual_inversion_learning_rate: float | None = 1e-3\n    \"\"\"The learning rate to use for textual inversion training of the embeddings. 
Set to null or 0 to use the\n    optimizer's default learning rate.\n    \"\"\"\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    lora_rank_dim: int = 4\n    \"\"\"The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,\n    but also increases the size of the generated LoRA model.\n    \"\"\"\n\n    cache_text_encoder_outputs: bool = False\n    \"\"\"If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and\n    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the\n    text encoders in VRAM), and speeds up training  (don't have to run the text encoders for each training example).\n    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being\n    applied.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. 
This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all\n    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face\n    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.\n    \"\"\"\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. 
Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If true, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'.\n    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Max gradient norm for clipping. Set to null or 0 for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can\n    become quite slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: TextualInversionSDDataLoaderConfig\n    \"\"\"The data configuration.\n\n    See\n    [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]\n    for details.\n    \"\"\"\n\n    vae_model: str | None = None\n    \"\"\"The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base\n    model (specified by the `model` parameter). 
This config option is provided for SDXL models, because SDXL shipped\n    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py",
    "content": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport time\nfrom pathlib import Path\nfrom typing import Literal\n\nimport peft\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom diffusers import UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTextModel\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\nfrom invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (\n    build_textual_inversion_sd_dataloader,\n)\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    TEXT_ENCODER_TARGET_MODULES,\n    UNET_TARGET_MODULES,\n    save_sdxl_kohya_checkpoint,\n    save_sdxl_peft_checkpoint,\n)\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl\nfrom invoke_training._shared.stable_diffusion.textual_inversion import restore_original_embeddings\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.train import train_forward\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (\n    
SdxlLoraAndTextualInversionConfig,\n)\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import _initialize_placeholder_tokens\n\n\ndef _save_sdxl_lora_and_ti_checkpoint(\n    config: SdxlLoraAndTextualInversionConfig,\n    epoch: int,\n    step: int,\n    unet: peft.PeftModel | None,\n    text_encoder_1: peft.PeftModel | None,\n    text_encoder_2: peft.PeftModel | None,\n    placeholder_token_ids_1: list[int],\n    placeholder_token_ids_2: list[int],\n    accelerator: Accelerator,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    lora_checkpoint_format: Literal[\"invoke_peft\", \"kohya\"],\n    callbacks: list[PipelineCallbacks] | None,\n):\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    training_checkpoint = TrainingCheckpoint(models=[], epoch=epoch, step=step)\n\n    if lora_checkpoint_format == \"invoke_peft\":\n        save_sdxl_peft_checkpoint(\n            Path(save_path),\n            unet=unet if config.train_unet else None,\n            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,\n            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,\n        )\n        training_checkpoint.models.append(ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_LORA_PEFT))\n    elif lora_checkpoint_format == \"kohya\":\n        save_sdxl_kohya_checkpoint(\n            Path(save_path) / \"lora.safetensors\",\n            unet=unet if config.train_unet else None,\n            text_encoder_1=text_encoder_1 if config.train_text_encoder else None,\n            text_encoder_2=text_encoder_2 if config.train_text_encoder else None,\n        )\n        training_checkpoint.models.append(ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_LORA_KOHYA))\n 
   else:\n        raise ValueError(f\"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.\")\n\n    if config.train_ti:\n        ti_checkpoint_path = Path(save_path) / \"embeddings.safetensors\"\n        learned_embeds_1 = (\n            accelerator.unwrap_model(text_encoder_1)\n            .get_input_embeddings()\n            .weight[min(placeholder_token_ids_1) : max(placeholder_token_ids_1) + 1]\n        )\n        learned_embeds_2 = (\n            accelerator.unwrap_model(text_encoder_2)\n            .get_input_embeddings()\n            .weight[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1]\n        )\n        learned_embeds_dict = {\n            \"clip_l\": learned_embeds_1.detach().cpu().to(dtype=torch.float32),\n            \"clip_g\": learned_embeds_2.detach().cpu().to(dtype=torch.float32),\n        }\n        save_state_dict(learned_embeds_dict, ti_checkpoint_path)\n        training_checkpoint.models.append(\n            ModelCheckpoint(file_path=ti_checkpoint_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION)\n        )\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(training_checkpoint)\n\n\ndef train(config: SdxlLoraAndTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Give a clear error message if an unsupported base model was chosen.\n    # TODO(ryan): Update this check to work with single-file SD checkpoints.\n    # check_base_model_version(\n    #     {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},\n    #     config.model,\n    #     local_files_only=False,\n    # )\n\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    
logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(\n        logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        vae_model=config.vae_model,\n        dtype=weight_dtype,\n    )\n\n    if config.xformers:\n        import_xformers()\n\n        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers\n        # will fail because Q, K, V have different dtypes.\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare text encoder output cache.\n    # text_encoder_output_cache_dir_name = None\n    if config.cache_text_encoder_outputs:\n        raise NotImplementedError(\"Caching text encoder outputs is not yet supported.\")\n    else:\n        text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n        text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        raise NotImplementedError(\"Caching VAE outputs is not yet supported.\")\n    else:\n        
vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n\n    # Add LoRA layers to the models being trained.\n    trainable_param_groups = []\n    all_trainable_models: set[torch.nn.Module] = set()\n\n    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.PeftModel:\n        peft_model = peft.get_peft_model(model, lora_config)\n        peft_model.print_trainable_parameters()\n\n        # Populate `trainable_param_groups`, to be passed to the optimizer.\n        param_group = {\"params\": list(filter(lambda p: p.requires_grad, peft_model.parameters())), \"lr\": lr}\n        trainable_param_groups.append(param_group)\n\n        # Populate all_trainable_models.\n        all_trainable_models.add(peft_model)\n        peft_model.train()\n        return peft_model\n\n    if config.train_unet:\n        unet_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. 
Is that preferred?\n            lora_alpha=1.0,\n            target_modules=UNET_TARGET_MODULES,\n        )\n        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)\n\n    if config.train_text_encoder:\n        text_encoder_lora_config = peft.LoraConfig(\n            r=config.lora_rank_dim,\n            lora_alpha=1.0,\n            # init_lora_weights=\"gaussian\",\n            target_modules=TEXT_ENCODER_TARGET_MODULES,\n        )\n        text_encoder_1 = inject_lora_layers(\n            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate\n        )\n        text_encoder_2 = inject_lora_layers(\n            text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate\n        )\n\n    if config.train_ti:\n        # TODO(ryand): Move this private function to a shared location.\n        placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens(\n            config=config,\n            tokenizer_1=tokenizer_1,\n            tokenizer_2=tokenizer_2,\n            text_encoder_1=text_encoder_1,\n            text_encoder_2=text_encoder_2,\n            logger=logger,\n        )\n        logger.info(f\"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.\")\n\n        # Unfreeze the token embeddings in the text encoders.\n        text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True)\n        text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True)\n\n        all_trainable_models.add(text_encoder_1)\n        all_trainable_models.add(text_encoder_2)\n\n        for te in [text_encoder_1, text_encoder_2]:\n            param_group = {\n                \"params\": te.get_input_embeddings().parameters(),\n                \"lr\": config.textual_inversion_learning_rate,\n            }\n            trainable_param_groups.append(param_group)\n\n    # If mixed_precision is enabled, cast all trainable 
params to float32.\n    if config.mixed_precision != \"no\":\n        for trainable_model in all_trainable_models:\n            for param in trainable_model.parameters():\n                if param.requires_grad:\n                    param.data = param.to(torch.float32)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n        if config.train_text_encoder:\n            for te in [text_encoder_1, text_encoder_2]:\n                te.gradient_checkpointing_enable()\n\n                # The text encoders must be in train() mode for gradient checkpointing to take effect. This should\n                # already be the case, since we are training the text_encoders, but we do it explicitly to make it clear\n                # that this is required.\n                # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text\n                # encoders in train mode does not change their forward behavior.\n                te.train()\n\n                # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder\n                # LoRA weights would have 0 gradients, and so would not get trained. 
Note that the set of\n                # trainable_param_groups has already been populated - this won't change what gets trained.\n                te.text_model.embeddings.requires_grad_(True)\n\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config.data_loader,\n        placeholder_token=config.placeholder_token,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        UNet2DConditionModel,\n        peft.PeftModel | CLIPTextModel,\n        peft.PeftModel | CLIPTextModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(\n        unet,\n        text_encoder_1,\n        text_encoder_2,\n        optimizer,\n        data_loader,\n        lr_scheduler,\n        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.\n        device_placement=[\n            True,\n            not config.cache_text_encoder_outputs,\n            not config.cache_text_encoder_outputs,\n            True,\n            True,\n            True,\n        ],\n    )\n    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"lora_and_ti_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        max_checkpoints=config.max_checkpoints,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * 
config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = first_epoch\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    ti_train_steps = num_train_steps\n    if config.ti_train_steps_ratio is not None:\n        ti_train_steps = math.ceil(num_train_steps * config.ti_train_steps_ratio)\n        logger.info(f\"The TI training pivot point is set at {ti_train_steps} steps.\")\n\n    # Keep original embeddings as reference.\n    with torch.no_grad():\n        orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()\n        orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_sdxl_lora_and_ti_checkpoint(\n                config=config,\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                unet=unet,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n               
 placeholder_token_ids_1=placeholder_token_ids_1,\n                placeholder_token_ids_2=placeholder_token_ids_2,\n                accelerator=accelerator,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                lora_checkpoint_format=config.lora_checkpoint_format,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sdxl(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        # TODO(ryand): Is this necessary?\n        text_encoder_1.train()\n        text_encoder_2.train()\n\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            if global_step == ti_train_steps and config.train_ti:\n                logger.info(\"Reached TI training pivot point. Setting TI learning rate to 0.0.\")\n                # TODO(ryand): The TI embeddings continue to be updated slightly by the normalization step in\n                # restore_original_embeddings(...). The updates should be very small and converge quickly, so this\n                # should be fine. 
But, at some point we should tidy this up.\n                for ti_param_group in optimizer.param_groups[-2:]:\n                    # The TI param groups should be the last two param groups. But, this is pretty brittle, so this\n                    # assertion adds a bit of safety.\n                    assert len(ti_param_group[\"params\"]) == 1\n                    ti_param_group[\"lr\"] = 0.0\n\n            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):\n                loss = train_forward(\n                    accelerator=accelerator,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer_1=tokenizer_1,\n                    tokenizer_2=tokenizer_2,\n                    text_encoder_1=text_encoder_1,\n                    text_encoder_2=text_encoder_2,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    resolution=config.data_loader.resolution,\n                    use_masks=config.use_masks,\n                    prediction_type=config.prediction_type,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n        
        lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n                # Make sure we don't update any embedding weights besides the newly-added token(s).\n                # TODO(ryand): Should we only do this if accelerator.sync_gradients?\n                restore_original_embeddings(\n                    tokenizer=tokenizer_1,\n                    placeholder_token_ids=placeholder_token_ids_1,\n                    accelerator=accelerator,\n                    text_encoder=text_encoder_1,\n                    orig_embeds_params=orig_embeds_params_1,\n                )\n                restore_original_embeddings(\n                    tokenizer=tokenizer_2,\n                    placeholder_token_ids=placeholder_token_ids_2,\n                    accelerator=accelerator,\n                    text_encoder=text_encoder_2,\n                    orig_embeds_params=orig_embeds_params_2,\n                )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss}\n\n                lrs = lr_scheduler.get_last_lr()\n\n                # Prepare LR names in the same order that their respective param groups were added to the optimizer.\n                # TODO: Do this at the time that we prepare the param groups?\n                lr_names = []\n                if config.train_unet:\n                    lr_names.append(\"unet\")\n                if config.train_text_encoder:\n                    lr_names.append(\"text_encoder_1\")\n                    lr_names.append(\"text_encoder_2\")\n                if config.train_ti:\n                    lr_names.append(\"ti_embeddings_1\")\n                    lr_names.append(\"ti_embeddings_2\")\n\n      
          for lr_idx, lr_name in enumerate(lr_names):\n                    log[f\"lr/{lr_name}\"] = float(lrs[lr_idx])\n                    if config.optimizer.optimizer_type == \"Prodigy\":\n                        log[f\"lr/d*lr/{lr_name}\"] = (\n                            optimizer.param_groups[lr_idx][\"d\"] * optimizer.param_groups[lr_idx][\"lr\"]\n                        )\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the *number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, 
num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py",
    "content": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\n\n\nclass SdxlTextualInversionConfig(BasePipelineConfig):\n    type: Literal[\"SDXL_TEXTUAL_INVERSION\"] = \"SDXL_TEXTUAL_INVERSION\"\n    \"\"\"Must be `SDXL_TEXTUAL_INVERSION`. This is what differentiates training pipeline types.\n    \"\"\"\n\n    model: str = \"stabilityai/stable-diffusion-xl-base-1.0\"\n    \"\"\"Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint\n    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )\n    \"\"\"\n\n    hf_variant: str | None = \"fp16\"\n    \"\"\"The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.\n    \"\"\"\n\n    # Helpful discussion for understanding how this works at inference time:\n    # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509\n    num_vectors: int = 1\n    \"\"\"Note: `num_vectors` can be overridden by `initial_phrase`.\n\n    The number of textual inversion embedding vectors that will be used to learn the concept.\n\n    Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:\n\n    - greater risk of overfitting\n    - increased size of the resulting output file\n    - consumes more of the prompt capacity at inference time\n\n    Typical values for `num_vectors` are in the range [1, 16].\n\n    As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).\n    \"\"\"\n\n    placeholder_token: str\n    \"\"\"The special word to associate the learned embeddings with. 
Choose a unique token that is unlikely to already\n    exist in the tokenizer's vocabulary.\n    \"\"\"\n\n    initializer_token: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly\n    describes the object or style that you're trying to train on. Must map to a single tokenizer token.\n\n    For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.\n    \"\"\"\n\n    initial_embedding_file: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder\n    token in the file must match the `placeholder_token` field.\n\n    Initializing from an embedding file is not yet supported for SDXL (raises `NotImplementedError`).\n    \"\"\"\n\n    initial_phrase: str | None = None\n    \"\"\"Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.\n\n    A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the\n    corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be\n    inferred from the length of the tokenized phrase, so keep the phrase short. 
The consequences of training a large\n    number of embedding vectors are discussed in the `num_vectors` field documentation.\n\n    For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.\n    \"\"\"\n\n    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()\n\n    lr_scheduler: Literal[\n        \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\"\n    ] = \"constant\"\n\n    lr_warmup_steps: int = 0\n    \"\"\"The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.\n    See lr_scheduler.\n    \"\"\"\n\n    min_snr_gamma: float | None = 5.0\n    \"\"\"Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy\n    improves the speed of training convergence by adjusting the weight of each sample.\n\n    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.\n\n    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.\n    \"\"\"\n\n    cache_vae_outputs: bool = False\n    \"\"\"If True, the VAE will be applied to all of the images in the dataset before starting training and the results\n    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and\n    speeds up training (don't have to run the VAE encoding step).\n\n    This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`,\n    `random_flip=False`, etc.).\n    \"\"\"\n\n    enable_cpu_offload_during_validation: bool = False\n    \"\"\"If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation\n    images. 
This reduces VRAM requirements at the cost of slower generation of validation images.\n    \"\"\"\n\n    gradient_accumulation_steps: int = 1\n    \"\"\"The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the\n    `train_batch_size` when training with limited VRAM.\n    \"\"\"\n\n    weight_dtype: Literal[\"float32\", \"float16\", \"bfloat16\"] = \"bfloat16\"\n    \"\"\"All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and\n    result in faster training, but are more prone to issues with numerical stability.\n\n    Recommendations:\n\n    - `\"float32\"`: Use this mode if you have plenty of VRAM available.\n    - `\"bfloat16\"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.\n    - `\"float16\"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.\n\n    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig.mixed_precision].\n    \"\"\"  # noqa: E501\n\n    mixed_precision: Literal[\"no\", \"fp16\", \"bf16\", \"fp8\"] = \"no\"\n    \"\"\"The mixed precision mode to use.\n\n    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and\n    trainable parameters are kept in float32 precision to avoid issues with numerical stability.\n\n    This value is passed to Hugging Face Accelerate. See\n    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)\n    for more details.\n    \"\"\"  # noqa: E501\n\n    xformers: bool = False\n    \"\"\"If `True`, use xformers for more efficient attention blocks.\n    \"\"\"\n\n    gradient_checkpointing: bool = False\n    \"\"\"Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. 
Enabling\n    gradient checkpointing slows down training by ~20%.\n    \"\"\"\n\n    max_checkpoints: int | None = None\n    \"\"\"The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this\n    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.\n    \"\"\"\n\n    prediction_type: Literal[\"epsilon\", \"v_prediction\"] | None = None\n    \"\"\"The prediction type that will be used for training. If `None`, the prediction type will be inferred from the\n    scheduler.\n    \"\"\"\n\n    max_grad_norm: float | None = None\n    \"\"\"Maximum gradient norm for gradient clipping. Set to `None` for no clipping.\n    \"\"\"\n\n    validation_prompts: list[str] = []\n    \"\"\"A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.\n    \"\"\"\n\n    negative_validation_prompts: list[str] | None = None\n    \"\"\"A list of negative prompts that will be applied when generating validation images. If set, this list should have\n    the same length as 'validation_prompts'.\n    \"\"\"\n\n    num_validation_images_per_prompt: int = 4\n    \"\"\"The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can\n    become very slow if this number is too large.\n    \"\"\"\n\n    train_batch_size: int = 4\n    \"\"\"The training batch size.\n    \"\"\"\n\n    use_masks: bool = False\n    \"\"\"If True, image masks will be applied to weight the loss during training. 
The dataset must contain masks for this\n    feature to be used.\n    \"\"\"\n\n    data_loader: TextualInversionSDDataLoaderConfig\n    \"\"\"The data configuration.\n\n    See\n    [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]\n    for details.\n    \"\"\"\n\n    vae_model: str | None = None\n    \"\"\"The name of the Hugging Face Hub VAE model to train against. If set, this will override the VAE bundled with the\n    base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL 1.0\n    shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.\n    \"\"\"\n\n    @model_validator(mode=\"after\")\n    def check_validation_prompts(self):\n        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(\n            self.validation_prompts\n        ):\n            raise ValueError(\n                f\"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of \"\n                f\"negative_validation_prompts ({len(self.negative_validation_prompts)}).\"\n            )\n        return self\n"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py",
    "content": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom diffusers.optimization import get_scheduler\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPPreTrainedModel, CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer\n\nfrom invoke_training._shared.accelerator.accelerator_utils import (\n    get_dtype_from_str,\n    initialize_accelerator,\n    initialize_logging,\n)\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\nfrom invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (\n    build_textual_inversion_sd_dataloader,\n)\nfrom invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets\nfrom invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl\nfrom invoke_training._shared.stable_diffusion.textual_inversion import (\n    initialize_placeholder_tokens_from_initial_phrase,\n    initialize_placeholder_tokens_from_initializer_token,\n    restore_original_embeddings,\n)\nfrom invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl\nfrom invoke_training._shared.utils.import_xformers import import_xformers\nfrom invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.train import cache_vae_outputs, train_forward\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (\n    SdxlLoraAndTextualInversionConfig,\n)\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import 
SdxlTextualInversionConfig\n\n\ndef _save_ti_embeddings(\n    epoch: int,\n    step: int,\n    text_encoder_1: CLIPTextModel,\n    text_encoder_2: CLIPTextModel,\n    placeholder_token_ids_1: list[int],\n    placeholder_token_ids_2: list[int],\n    accelerator: Accelerator,\n    logger: logging.Logger,\n    checkpoint_tracker: CheckpointTracker,\n    callbacks: list[PipelineCallbacks] | None,\n):\n    \"\"\"Save a Textual Inversion SDXL checkpoint. Old checkpoints are deleted if necessary to respect the\n    checkpoint_tracker limits.\n    \"\"\"\n    # Prune checkpoints and get new checkpoint path.\n    num_pruned = checkpoint_tracker.prune(1)\n    if num_pruned > 0:\n        logger.info(f\"Pruned {num_pruned} checkpoint(s).\")\n    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)\n\n    learned_embeds_1 = (\n        accelerator.unwrap_model(text_encoder_1)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids_1) : max(placeholder_token_ids_1) + 1]\n    )\n    learned_embeds_2 = (\n        accelerator.unwrap_model(text_encoder_2)\n        .get_input_embeddings()\n        .weight[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1]\n    )\n    learned_embeds_dict = {\n        \"clip_l\": learned_embeds_1.detach().cpu().to(dtype=torch.float32),\n        \"clip_g\": learned_embeds_2.detach().cpu().to(dtype=torch.float32),\n    }\n\n    save_state_dict(learned_embeds_dict, save_path)\n\n    if callbacks is not None:\n        for cb in callbacks:\n            cb.on_save_checkpoint(\n                TrainingCheckpoint(\n                    models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION)],\n                    epoch=epoch,\n                    step=step,\n                )\n            )\n\n\ndef _initialize_placeholder_tokens(\n    config: SdxlTextualInversionConfig | SdxlLoraAndTextualInversionConfig,\n    tokenizer_1: CLIPTokenizer,\n    tokenizer_2: CLIPTokenizer,\n    
text_encoder_1: PreTrainedTokenizer,\n    text_encoder_2: PreTrainedTokenizer,\n    logger: logging.Logger,\n) -> tuple[list[str], list[int], list[int]]:\n    \"\"\"Prepare the tokenizers and text_encoders for TI training.\n\n    - Add the placeholder tokens to the tokenizers.\n    - Add new token embeddings to the text_encoders for each of the placeholder tokens.\n    - Initialize the new token embeddings from either an existing token, or an initial TI embedding file.\n    \"\"\"\n\n    if (\n        sum(\n            [\n                getattr(config, \"initializer_token\", None) is not None,\n                getattr(config, \"initial_embedding_file\", None) is not None,\n                getattr(config, \"initial_phrase\", None) is not None,\n            ]\n        )\n        != 1\n    ):\n        raise ValueError(\n            \"Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set.\"\n        )\n\n    if getattr(config, \"initializer_token\", None) is not None:\n        placeholder_tokens_1, placeholder_token_ids_1 = initialize_placeholder_tokens_from_initializer_token(\n            tokenizer=tokenizer_1,\n            text_encoder=text_encoder_1,\n            initializer_token=config.initializer_token,\n            placeholder_token=config.placeholder_token,\n            num_vectors=config.num_vectors,\n            logger=logger,\n        )\n        placeholder_tokens_2, placeholder_token_ids_2 = initialize_placeholder_tokens_from_initializer_token(\n            tokenizer=tokenizer_2,\n            text_encoder=text_encoder_2,\n            initializer_token=config.initializer_token,\n            placeholder_token=config.placeholder_token,\n            num_vectors=config.num_vectors,\n            logger=logger,\n        )\n    elif getattr(config, \"initial_embedding_file\", None) is not None:\n        # TODO(ryan)\n        raise NotImplementedError(\"Initializing from an initial embedding is not yet supported for 
SDXL.\")\n    elif getattr(config, \"initial_phrase\", None) is not None:\n        placeholder_tokens_1, placeholder_token_ids_1 = initialize_placeholder_tokens_from_initial_phrase(\n            tokenizer=tokenizer_1,\n            text_encoder=text_encoder_1,\n            initial_phrase=config.initial_phrase,\n            placeholder_token=config.placeholder_token,\n        )\n        placeholder_tokens_2, placeholder_token_ids_2 = initialize_placeholder_tokens_from_initial_phrase(\n            tokenizer=tokenizer_2,\n            text_encoder=text_encoder_2,\n            initial_phrase=config.initial_phrase,\n            placeholder_token=config.placeholder_token,\n        )\n    else:\n        raise ValueError(\n            \"Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set.\"\n        )\n\n    assert placeholder_tokens_1 == placeholder_tokens_2\n    return placeholder_tokens_1, placeholder_token_ids_1, placeholder_token_ids_2\n\n\ndef train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901\n    # Create a timestamped directory for all outputs.\n    out_dir = os.path.join(config.base_output_dir, f\"{time.time()}\")\n    ckpt_dir = os.path.join(out_dir, \"checkpoints\")\n    os.makedirs(ckpt_dir)\n\n    accelerator = initialize_accelerator(\n        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to\n    )\n    logger = initialize_logging(os.path.basename(__file__), accelerator)\n\n    # Set the accelerate seed.\n    if config.seed is not None:\n        set_seed(config.seed)\n\n    # Log the accelerator configuration from every process to help with debugging.\n    logger.info(accelerator.state, main_process_only=False)\n\n    logger.info(\"Starting Training.\")\n    logger.info(f\"Configuration:\\n{json.dumps(config.dict(), indent=2, default=str)}\")\n    logger.info(f\"Output dir: '{out_dir}'\")\n\n    # Write the configuration to 
disk.\n    with open(os.path.join(out_dir, \"config.json\"), \"w\") as f:\n        json.dump(config.dict(), f, indent=2, default=str)\n\n    weight_dtype = get_dtype_from_str(config.weight_dtype)\n\n    logger.info(\"Loading models.\")\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(\n        logger=logger,\n        model_name_or_path=config.model,\n        hf_variant=config.hf_variant,\n        vae_model=config.vae_model,\n        dtype=weight_dtype,\n    )\n\n    placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens(\n        config=config,\n        tokenizer_1=tokenizer_1,\n        tokenizer_2=tokenizer_2,\n        text_encoder_1=text_encoder_1,\n        text_encoder_2=text_encoder_2,\n        logger=logger,\n    )\n    logger.info(f\"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.\")\n\n    # All parameters of the VAE, UNet, and text encoder are currently frozen. 
Just unfreeze the token embeddings in the\n    # text encoders.\n    text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True)\n    text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True)\n\n    if config.gradient_checkpointing:\n        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.\n        unet.enable_gradient_checkpointing()\n        # unet must be in train() mode for gradient checkpointing to take effect.\n        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does\n        # not change its forward behavior.\n        unet.train()\n        for te in [text_encoder_1, text_encoder_2]:\n            # The text_encoder will be put in .train() mode later, so we don't need to worry about that here.\n            # Note: There are some weird interactions between gradient checkpointing and requires_grad_() when training\n            # a text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how this is\n            # handled in other training pipelines.\n            te.gradient_checkpointing_enable()\n\n    if config.xformers:\n        import_xformers()\n\n        unet.enable_xformers_memory_efficient_attention()\n        vae.enable_xformers_memory_efficient_attention()\n\n    # Prepare VAE output cache.\n    vae_output_cache_dir_name = None\n    if config.cache_vae_outputs:\n        if config.data_loader.random_flip:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'random_flip' is True.\")\n        if not config.data_loader.center_crop:\n            raise ValueError(\"'cache_vae_outputs' cannot be True if 'center_crop' is False.\")\n\n        # We use a temporary directory for the cache. 
The directory will automatically be cleaned up when\n        # tmp_vae_output_cache_dir is destroyed.\n        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()\n        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name\n        if accelerator.is_local_main_process:\n            # Only the main process should to populate the cache.\n            logger.info(f\"Generating VAE output cache ('{vae_output_cache_dir_name}').\")\n            vae.to(accelerator.device, dtype=weight_dtype)\n            data_loader = build_textual_inversion_sd_dataloader(\n                config=config.data_loader,\n                placeholder_token=config.placeholder_token,\n                batch_size=config.train_batch_size,\n                use_masks=config.use_masks,\n                shuffle=False,\n            )\n            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)\n        # Move the VAE back to the CPU, because it is not needed for training.\n        vae.to(\"cpu\")\n        accelerator.wait_for_everyone()\n    else:\n        vae.to(accelerator.device, dtype=weight_dtype)\n\n    unet.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_1.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_2.to(accelerator.device, dtype=weight_dtype)\n\n    # Initialize the optimizer to only optimize the token embeddings.\n    trainable_param_groups = [\n        {\"params\": text_encoder_1.get_input_embeddings().parameters()},\n        {\"params\": text_encoder_2.get_input_embeddings().parameters()},\n    ]\n    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)\n    trainable_models = torch.nn.ModuleDict({\"text_encoder_1\": text_encoder_1, \"text_encoder_2\": text_encoder_2})\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config.data_loader,\n        placeholder_token=config.placeholder_token,\n        batch_size=config.train_batch_size,\n        use_masks=config.use_masks,\n        
vae_output_cache_dir=vae_output_cache_dir_name,\n    )\n\n    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)\n\n    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1\n    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1\n    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1\n\n    # A \"step\" represents a single weight update operation (i.e. takes into account gradient accumulation steps).\n    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when\n    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.\n    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)\n    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch\n    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)\n\n    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps\n    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears\n    # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process\n    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),\n    # so the scaling here simply reverses that behaviour.\n    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(\n        config.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=num_train_steps * accelerator.num_processes,\n    )\n\n    prepared_result: tuple[\n        CLIPPreTrainedModel,\n        CLIPPreTrainedModel,\n        torch.optim.Optimizer,\n        torch.utils.data.DataLoader,\n        torch.optim.lr_scheduler.LRScheduler,\n    ] = accelerator.prepare(text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler)\n    text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result\n\n    if accelerator.is_main_process:\n        accelerator.init_trackers(\"textual_inversion_training\")\n        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.\n        accelerator.log({\"configuration\": f\"```json\\n{json.dumps(config.dict(), indent=2, default=str)}\\n```\\n\"})\n\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=ckpt_dir,\n        prefix=\"checkpoint\",\n        extension=\".safetensors\",\n        max_checkpoints=config.max_checkpoints,\n    )\n\n    # Train!\n    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches = {len(data_loader)}\")\n    logger.info(f\"  Instantaneous batch size per device = {config.train_batch_size}\")\n    logger.info(f\"  Gradient accumulation steps = {config.gradient_accumulation_steps}\")\n    logger.info(f\"  Parallel processes = {accelerator.num_processes}\")\n    logger.info(f\"  Total train 
batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Total optimization steps = {num_train_steps}\")\n    logger.info(f\"  Total epochs = {num_train_epochs}\")\n\n    global_step = 0\n    first_epoch = 0\n    completed_epochs = 0\n\n    progress_bar = tqdm(\n        range(global_step, num_train_steps),\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n    progress_bar.set_description(\"Steps\")\n\n    # Keep original embeddings as reference.\n    with torch.no_grad():\n        orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()\n        orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()\n\n    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            _save_ti_embeddings(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                placeholder_token_ids_1=placeholder_token_ids_1,\n                placeholder_token_ids_2=placeholder_token_ids_2,\n                accelerator=accelerator,\n                logger=logger,\n                checkpoint_tracker=checkpoint_tracker,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    def validate(num_completed_epochs: int, num_completed_steps: int):\n        accelerator.wait_for_everyone()\n        if accelerator.is_main_process:\n            generate_validation_images_sdxl(\n                epoch=num_completed_epochs,\n                step=num_completed_steps,\n                out_dir=out_dir,\n                accelerator=accelerator,\n                vae=vae,\n                
text_encoder_1=text_encoder_1,\n                text_encoder_2=text_encoder_2,\n                tokenizer_1=tokenizer_1,\n                tokenizer_2=tokenizer_2,\n                noise_scheduler=noise_scheduler,\n                unet=unet,\n                config=config,\n                logger=logger,\n                callbacks=callbacks,\n            )\n        accelerator.wait_for_everyone()\n\n    for epoch in range(first_epoch, num_train_epochs):\n        text_encoder_1.train()\n        text_encoder_2.train()\n\n        train_loss = 0.0\n        for data_batch_idx, data_batch in enumerate(data_loader):\n            with accelerator.accumulate(trainable_models):\n                loss = train_forward(\n                    accelerator=accelerator,\n                    data_batch=data_batch,\n                    vae=vae,\n                    noise_scheduler=noise_scheduler,\n                    tokenizer_1=tokenizer_1,\n                    tokenizer_2=tokenizer_2,\n                    text_encoder_1=text_encoder_1,\n                    text_encoder_2=text_encoder_2,\n                    unet=unet,\n                    weight_dtype=weight_dtype,\n                    resolution=config.data_loader.resolution,\n                    use_masks=config.use_masks,\n                    prediction_type=config.prediction_type,\n                    min_snr_gamma=config.min_snr_gamma,\n                )\n\n                # Gather the losses across all processes for logging (if we use distributed training).\n                # TODO(ryand): Test that this works properly with distributed training.\n                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()\n                train_loss += avg_loss.item() / config.gradient_accumulation_steps\n\n                # Backpropagate.\n                accelerator.backward(loss)\n                if accelerator.sync_gradients and config.max_grad_norm is not None:\n                    # TODO(ryand): I copied this 
from another pipeline. Should probably just clip the trainable params.\n                    params_to_clip = trainable_models.parameters()\n                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad(set_to_none=True)\n\n                # Make sure we don't update any embedding weights besides the newly-added token(s).\n                # TODO(ryand): Should we only do this if accelerator.sync_gradients?\n                restore_original_embeddings(\n                    tokenizer=tokenizer_1,\n                    placeholder_token_ids=placeholder_token_ids_1,\n                    accelerator=accelerator,\n                    text_encoder=text_encoder_1,\n                    orig_embeds_params=orig_embeds_params_1,\n                )\n                restore_original_embeddings(\n                    tokenizer=tokenizer_2,\n                    placeholder_token_ids=placeholder_token_ids_2,\n                    accelerator=accelerator,\n                    text_encoder=text_encoder_2,\n                    orig_embeds_params=orig_embeds_params_2,\n                )\n\n            # Checks if the accelerator has performed an optimization step behind the scenes.\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1\n                log = {\"train_loss\": train_loss, \"lr\": lr_scheduler.get_last_lr()[0]}\n\n                if config.optimizer.optimizer_type == \"Prodigy\":\n                    # TODO(ryand): Test Prodigy logging.\n                    log[\"lr/d*lr\"] = optimizer.param_groups[0][\"d\"] * optimizer.param_groups[0][\"lr\"]\n\n                accelerator.log(log, step=global_step)\n                train_loss = 0.0\n\n                # global_step represents the 
*number of completed steps* at this point.\n                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:\n                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n                if (\n                    config.validate_every_n_steps is not None\n                    and global_step % config.validate_every_n_steps == 0\n                    and len(config.validation_prompts) > 0\n                ):\n                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n            logs = {\n                \"step_loss\": loss.detach().item(),\n                \"lr\": lr_scheduler.get_last_lr()[0],\n            }\n            progress_bar.set_postfix(**logs)\n\n            if global_step >= num_train_steps:\n                break\n\n        # Save a checkpoint every n epochs.\n        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:\n            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n        # Generate validation images every n epochs.\n        if (\n            config.validate_every_n_epochs is not None\n            and completed_epochs % config.validate_every_n_epochs == 0\n            and len(config.validation_prompts) > 0\n        ):\n            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)\n\n    accelerator.end_training()\n"
  },
  {
    "path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml",
    "content": "# Training mode: Direct Preference Optimization LoRA Training\n# Dataset: A small subset of the pickapic_v2 dataset.\n# Base model:    SD 1.5\n# GPU:           1 x 24GB\n#\n# Training takes ~2 hours on a single RTX 4090.\n\ntype: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA\nseed: 1\nbase_output_dir: output/dpo\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 1e-4\n  weight_decay: 1e-2\n\nlr_warmup_steps: 200\nlr_scheduler: cosine\n\ndata_loader:\n  type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER\n  dataset:\n    type: HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET\n  resolution: 512\n\n# General\nmodel: runwayml/stable-diffusion-v1-5\ngradient_accumulation_steps: 2\nweight_dtype: float16\nmixed_precision: fp16\ngradient_checkpointing: True\nmax_train_steps: 5000\nsave_every_n_epochs: 1\nsave_every_n_steps: null\nmax_checkpoints: 100\nvalidation_prompts:\n  - A monk in an orange robe by a round window in a spaceship in dramatic lighting\n  - A galaxy-colored figurine is floating over the sea at sunset, photorealistic\n  - Concept art of a mythical sky alligator with wings, nature documentary\nvalidate_every_n_epochs: 1\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 1\n"
  },
  {
    "path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml",
    "content": "# Training mode: Direct Preference Optimization LoRA Training\n# Base model:    SD 1.5\n# GPU:           1 x 24GB\n\ntype: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA\nseed: 1\nbase_output_dir: output/dpo\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 1e-4\n  weight_decay: 1e-2\n\nlr_warmup_steps: 500\nlr_scheduler: cosine\n\ndata_loader:\n  type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_PAIR_PREFERENCE_DATASET\n    dataset_dir: output/pokemon_pairs\n  resolution: 512\n  dataloader_num_workers: 4\n\n# General\nmodel: runwayml/stable-diffusion-v1-5\ninitial_lora: output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003\ngradient_accumulation_steps: 2\nweight_dtype: float16\nmixed_precision: fp16\ngradient_checkpointing: True\nmax_train_steps: 5000\nsave_every_n_epochs: 10\nsave_every_n_steps: null\nmax_checkpoints: 100\nvalidation_prompts:\n  - A cute yoda pokemon creature.\n  - A cute astronaut pokemon creature.\nvalidate_every_n_epochs: 10\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 2\n"
  },
  {
    "path": "src/invoke_training/sample_configs/flux_lora_1x40gb.yaml",
    "content": "# Training mode: LoRA\n# Base model:    Flux.1-dev\n# Dataset:       Bruce the Gnome\n# GPU:           1 x 40GB\n\ntype: FLUX_LORA\nseed: 1\nbase_output_dir: output/experiments/bruce_the_gnome/flux_lora\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 1e-4\n\nlr_warmup_steps: 1\nlr_scheduler: constant\ntransformer_learning_rate: 4e-4\ntext_encoder_learning_rate: 4e-4\ntrain_text_encoder: False\n\ndata_loader:\n  type: IMAGE_CAPTION_FLUX_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: sample_data/bruce_the_gnome/data.jsonl\n  resolution: 768\n  aspect_ratio_buckets:\n    target_resolution: 768\n    start_dim: 384\n    end_dim: 1536\n    divisible_by: 128\n  caption_prefix: \"bruce the gnome\"\n  dataloader_num_workers: 4\n\n# General\nmodel: black-forest-labs/FLUX.1-dev\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_steps: 350\nsave_every_n_steps: 50\nvalidate_every_n_steps: 50\n\nmax_checkpoints: 10\nvalidation_prompts:\n  - A stuffed gnome at the beach with a pina colada in its hand.\n  - A stuffed gnome reading a book in a cozy library.\n  - A stuffed gnome sitting in a garden surrounded by colorful flowers and butterflies.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3"
  },
  {
    "path": "src/invoke_training/sample_configs/sd_lora_baroque_1x8gb.yaml",
    "content": "# Training mode: Finetuning with LoRA\n# Base model:    SD 1.5\n# Dataset:       https://huggingface.co/datasets/InvokeAI/nga-baroque\n# GPU:           1 x 8GB\n\n# Instructions:\n# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.\n# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded\n# dataset.\n\n# Notes:\n# This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo purposes.\n\ntype: SD_LORA\nseed: 1\nbase_output_dir: output/baroque/sd_lora\n\noptimizer:\n  optimizer_type: Prodigy\n  learning_rate: 1.0\n  weight_decay: 0.01\n  use_bias_correction: True\n  safeguard_warmup: True\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: data/nga-baroque/metadata.jsonl\n  resolution: 512\n  aspect_ratio_buckets:\n    target_resolution: 512\n    start_dim: 256\n    end_dim: 768\n    divisible_by: 64\n  caption_prefix: \"A baroque painting of\"\n  dataloader_num_workers: 4\n\n# General\nmodel: runwayml/stable-diffusion-v1-5\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_epochs: 15\nsave_every_n_epochs: 1\nvalidate_every_n_epochs: 1\n\nmax_checkpoints: 5\nvalidation_prompts:\n  - A baroque painting of a woman carrying a basket of fruit.\n  - A baroque painting of a cute Yoda creature.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sd_textual_inversion_gnome_1x8gb.yaml",
    "content": "# Training mode: Textual Inversion\n# Base model:    SD v1\n# GPU:           1 x 24GB\n\ntype: SD_TEXTUAL_INVERSION\nseed: 1\nbase_output_dir: output/sd_ti_bruce_the_gnome\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 4e-3\n\nlr_warmup_steps: 200\nlr_scheduler: cosine\n\ndata_loader:\n  type: TEXTUAL_INVERSION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_DIR_DATASET\n    dataset_dir: \"sample_data/bruce_the_gnome\"\n    keep_in_memory: True\n  caption_preset: object\n  resolution: 512\n  center_crop: True\n  random_flip: False\n  shuffle_caption_delimiter: null\n  aspect_ratio_buckets:\n    target_resolution: 512\n    start_dim: 256\n    end_dim: 768\n    divisible_by: 64\n  dataloader_num_workers: 4\n\n# General\nmodel: runwayml/stable-diffusion-v1-5\nnum_vectors: 4\nplaceholder_token: \"bruce_the_gnome\"\ninitializer_token: \"gnome\"\ncache_vae_outputs: False\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_steps: 2000\nsave_every_n_steps: 200\nvalidate_every_n_steps: 200\n\nmax_checkpoints: 20\nvalidation_prompts:\n  - A photo of bruce_the_gnome at the beach\n  - A photo of bruce_the_gnome reading a book\ntrain_batch_size: 1\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_finetune_baroque_1x24gb.yaml",
    "content": "# Training mode: Full Finetuning\n# Base model:    SDXL\n# Dataset:       https://huggingface.co/datasets/InvokeAI/nga-baroque\n# GPU:           1 x 24GB\n\n# Instructions:\n# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.\n# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded\n# dataset.\n\ntype: SDXL_FINETUNE\nseed: 1\nbase_output_dir: output/baroque/sdxl_finetune\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 5e-5\n  weight_decay: 1e-3\n  use_8bit: True\n\nlr_scheduler: constant_with_warmup\nlr_warmup_steps: 500\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: data/nga-baroque/metadata.jsonl\n  resolution: 1024\n  aspect_ratio_buckets:\n    target_resolution: 1024\n    start_dim: 512\n    end_dim: 1536\n    divisible_by: 128\n  caption_prefix: \"A baroque style painting,\"\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nsave_checkpoint_format: trained_only_diffusers\n# vae_model: madebyollin/sdxl-vae-fp16-fix\nsave_dtype: float16\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\ncache_vae_outputs: True\ncache_text_encoder_outputs: True\n\nmax_train_epochs: 50\nsave_every_n_epochs: 3\nvalidate_every_n_epochs: 3\n# We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space.\nmax_checkpoints: 1\n\nvalidation_prompts:\n  - A baroque style painting of a woman carrying a basket of fruit.\n  - A baroque style painting of a cute Yoda creature.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml",
    "content": "# Training mode: Full finetune\n# Base model:    SDXL\n# Dataset:       Robocats\n# GPU:           1 x 24GB\n\ntype: SDXL_FINETUNE\nseed: 1\nbase_output_dir: output/robocats/sdxl_finetune\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 2e-5\n  use_8bit: True\n\nlr_scheduler: constant_with_warmup\nlr_warmup_steps: 200\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: /home/ryan/data/robocats/data.jsonl\n  resolution: 1024\n  aspect_ratio_buckets:\n    target_resolution: 1024\n    start_dim: 512\n    end_dim: 1536\n    divisible_by: 128\n  caption_prefix: \"In the robocat style,\"\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nsave_checkpoint_format: trained_only_diffusers\n# vae_model: madebyollin/sdxl-vae-fp16-fix\nsave_dtype: float16\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\ncache_vae_outputs: True\ncache_text_encoder_outputs: True\n\nmax_train_steps: 2000\nvalidate_every_n_steps: 200\nsave_every_n_steps: 2000\n# We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space.\nmax_checkpoints: 1\n\nvalidation_prompts:\n  - In the robocat style, a robotic lion in the jungle.\n  - In the robocat style, a hamburger and fries.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_and_ti_gnome_1x24gb.yaml",
    "content": "# Training mode: Finetuning with LoRA and Textual Inversion\n# Base model:    SDXL 1.0\n# GPU:           1 x 24GB\n\ntype: SDXL_LORA_AND_TEXTUAL_INVERSION\nseed: 1\nbase_output_dir: output/sdxl_lora_and_ti_bruce_the_gnome\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 2e-3\n\nlr_warmup_steps: 200\nlr_scheduler: constant\n\ndata_loader:\n  type: TEXTUAL_INVERSION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_DIR_DATASET\n    dataset_dir: \"sample_data/bruce_the_gnome\"\n    keep_in_memory: True\n  caption_preset: object\n  resolution: 1024\n  center_crop: True\n  random_flip: False\n  shuffle_caption_delimiter: null\n  dataloader_num_workers: 4\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nvae_model: madebyollin/sdxl-vae-fp16-fix\nnum_vectors: 2\nplaceholder_token: \"bruce_the_gnome\"\ninitializer_token: \"gnome\"\ncache_vae_outputs: False\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_steps: 2000\nsave_every_n_steps: 200\nvalidate_every_n_steps: 200\n\nmax_checkpoints: 50\nvalidation_prompts:\n  - A photo of bruce_the_gnome at the beach\n  - A photo of bruce_the_gnome reading a book\ntrain_batch_size: 1\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x24gb.yaml",
    "content": "# Training mode: Finetuning with LoRA\n# Base model:    SDXL 1.0\n# Dataset:       https://huggingface.co/datasets/InvokeAI/nga-baroque\n# GPU:           1 x 24GB\n\n# Instructions:\n# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.\n# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded\n# dataset.\n\n# Notes:\n# This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo\n# purposes.\n\ntype: SDXL_LORA\nseed: 1\nbase_output_dir: output/baroque/sdxl_lora\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 1e-3\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: data/nga-baroque/metadata_masks.jsonl\n  resolution: 1024\n  aspect_ratio_buckets:\n    target_resolution: 1024\n    start_dim: 512\n    end_dim: 1536\n    divisible_by: 128\n  caption_prefix: \"A baroque painting of\"\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\n# vae_model: madebyollin/sdxl-vae-fp16-fix\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\ncache_vae_outputs: True\n\nmax_train_epochs: 16\nsave_every_n_epochs: 2\nvalidate_every_n_epochs: 2\n\nuse_masks: False\n\nmax_checkpoints: 5\nvalidation_prompts:\n  - A baroque painting of a woman carrying a basket of fruit.\n  - A baroque painting of a cute Yoda creature.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x8gb.yaml",
    "content": "# Training mode: Finetuning with LoRA\n# Base model:    SDXL 1.0\n# Dataset:       https://huggingface.co/datasets/InvokeAI/nga-baroque\n# GPU:           1 x 8GB\n\n# Instructions:\n# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.\n# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded\n# dataset.\n\n# Notes:\n# This config file has been optimized for 2 primary goals:\n#   - Minimize VRAM usage so that an SDXL model can be trained with only 8GB of VRAM.\n#   - Achieve reasonable results *quickly* for demo purposes.\n\ntype: SDXL_LORA\nseed: 1\nbase_output_dir: output/baroque/sdxl_lora\n\noptimizer:\n  optimizer_type: Prodigy\n  learning_rate: 1.0\n  weight_decay: 0.01\n  use_bias_correction: True\n  safeguard_warmup: True\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.\n    jsonl_path: data/nga-baroque/metadata.jsonl\n  # TODO: More optimizations are needed to train at full 1024x1024 resolution with 8GB VRAM.\n  resolution: 512\n  # aspect_ratio_buckets:\n  #   target_resolution: 1024\n  #   start_dim: 512\n  #   end_dim: 1536\n  #   divisible_by: 128\n  caption_prefix: \"A baroque painting of\"\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nvae_model: madebyollin/sdxl-vae-fp16-fix\ntrain_text_encoder: False\ncache_text_encoder_outputs: True\nenable_cpu_offload_during_validation: True\ngradient_accumulation_steps: 4\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_epochs: 6\nsave_every_n_epochs: 1\nvalidate_every_n_epochs: 1\n\nmax_checkpoints: 5\nvalidation_prompts:\n  - A baroque painting of a woman carrying a basket of fruit.\n  - A baroque painting of a cute Yoda creature.\ntrain_batch_size: 1\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml",
    "content": "# Training mode: LoRA with masks\n# Base model:    SDXL 1.0\n# Dataset:       Bruce the Gnome\n# GPU:           1 x 24GB\n\ntype: SDXL_LORA\nseed: 1\nbase_output_dir: output/bruce/sdxl_lora_masks\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 7e-5\n\nlr_scheduler: constant_with_warmup\nlr_warmup_steps: 50\n\ndata_loader:\n  type: IMAGE_CAPTION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl\n  resolution: 1024\n  aspect_ratio_buckets:\n    target_resolution: 1024\n    start_dim: 512\n    end_dim: 1536\n    divisible_by: 128\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\n# vae_model: madebyollin/sdxl-vae-fp16-fix\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\ncache_vae_outputs: True\n\nmax_train_steps: 500\nsave_every_n_steps: 50\nvalidate_every_n_steps: 50\n\nuse_masks: True\n\nmax_checkpoints: 5\nvalidation_prompts:\n  - A stuffed gnome at the beach with a pina colada in its hand.\n  - A stuffed gnome reading a book in a cozy library.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml",
    "content": "# Training mode: Textual Inversion\n# Base model:    SDXL\n# GPU:           1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERSION\nseed: 1\nbase_output_dir: output/bruce/sdxl_ti\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 2e-3\n\nlr_warmup_steps: 200\nlr_scheduler: cosine\n\ndata_loader:\n  type: TEXTUAL_INVERSION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_DIR_DATASET\n    dataset_dir: \"sample_data/bruce_the_gnome\"\n    keep_in_memory: True\n  caption_preset: object\n  resolution: 1024\n  center_crop: True\n  random_flip: False\n  shuffle_caption_delimiter: null\n  dataloader_num_workers: 4\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nvae_model: madebyollin/sdxl-vae-fp16-fix\nnum_vectors: 4\nplaceholder_token: \"bruce_the_gnome\"\ninitializer_token: \"gnome\"\ncache_vae_outputs: False\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_steps: 2000\nsave_every_n_steps: 200\nvalidate_every_n_steps: 200\n\nmax_checkpoints: 20\nvalidation_prompts:\n  - A photo of bruce_the_gnome at the beach\n  - A photo of bruce_the_gnome reading a book\ntrain_batch_size: 1\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_textual_inversion_masks_gnome_1x24gb.yaml",
    "content": "# Training mode: Textual Inversion with Masks\n# Base model:    SDXL\n# GPU:           1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERSION\nseed: 1\nbase_output_dir: output/bruce/sdxl_ti_masks\n\noptimizer:\n  optimizer_type: AdamW\n  learning_rate: 5e-4\n\nlr_scheduler: constant_with_warmup\nlr_warmup_steps: 50\n\ndata_loader:\n  type: TEXTUAL_INVERSION_SD_DATA_LOADER\n  dataset:\n    type: IMAGE_CAPTION_JSONL_DATASET\n    jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl\n    keep_in_memory: True\n  caption_preset: object\n  resolution: 1024\n  center_crop: True\n  random_flip: False\n  shuffle_caption_delimiter: null\n\n# General\nmodel: stabilityai/stable-diffusion-xl-base-1.0\nnum_vectors: 16\nplaceholder_token: \"bruce_the_gnome\"\ninitializer_token: \"gnome\"\ncache_vae_outputs: False\ngradient_accumulation_steps: 1\nweight_dtype: bfloat16\ngradient_checkpointing: True\n\nmax_train_steps: 500\nsave_every_n_steps: 50\nvalidate_every_n_steps: 50\n\nuse_masks: True\n\nmax_checkpoints: 10\nvalidation_prompts:\n  - A photo of bruce_the_gnome at the beach with a pina colada in its hand.\n  - A photo of bruce_the_gnome reading a book in a cozy library.\ntrain_batch_size: 4\nnum_validation_images_per_prompt: 3\n"
  },
  {
    "path": "src/invoke_training/scripts/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py",
    "content": "import argparse\nimport json\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn\n\n\ndef select_device_and_dtype(force_cpu: bool = False) -> tuple[torch.device, torch.dtype]:\n    if force_cpu:\n        return torch.device(\"cpu\"), torch.float32\n\n    if torch.cuda.is_available():\n        return torch.device(\"cuda\"), torch.float16\n\n    return torch.device(\"cpu\"), torch.float32\n\n\ndef process_images(images: list[Image.Image], prompt: str, moondream, tokenizer) -> list[str]:\n    # image_embeds = moondream.encode_image(image).to(device=device)\n    # answer = moondream.answer_question(image_embeds, prompt, tokenizer)\n    answers = moondream.batch_answer(\n        images=images,\n        prompts=[prompt] * len(images),\n        tokenizer=tokenizer,\n    )\n    return answers\n\n\ndef main(\n    prompt: str,\n    use_cpu: bool,\n    batch_size: int,\n    output_path: str,\n    dataset: torch.utils.data.Dataset,\n):\n    device, dtype = select_device_and_dtype(use_cpu)\n    print(f\"Using device: {device}\")\n    print(f\"Using dtype: {dtype}\")\n\n    # Check that the output file does not already exist before spending time generating captions.\n    out_path = Path(output_path)\n    if out_path.exists():\n        raise FileExistsError(f\"Output file already exists: {out_path}\")\n\n    # Load the model.\n    model_id = \"vikhyatk/moondream2\"\n    model_revision = \"2024-04-02\"\n    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)\n    # TODO(ryand): Warn about security implications of trust_remote_code=True.\n    moondream_model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(\n        model_id, trust_remote_code=True, revision=model_revision\n    ).to(device=device, dtype=dtype)\n    
moondream_model.eval()\n\n    data_loader = torch.utils.data.DataLoader(\n        dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False\n    )\n\n    results = []\n    for image_batch in tqdm(data_loader):\n        image_paths = image_batch[\"image_path\"]\n        answers = process_images(image_batch[\"image\"], prompt, moondream_model, tokenizer)\n        for image_path, answer in zip(image_paths, answers, strict=True):\n            results.append({\"image\": image_path, \"text\": answer})\n\n    # Check that the output file does not exist immediately before writing to it.\n    if out_path.exists():\n        raise FileExistsError(f\"Output file already exists: {out_path}\")\n\n    with open(out_path, \"w\") as outfile:\n        for entry in results:\n            json.dump(entry, outfile)\n            outfile.write(\"\\n\")\n    print(\"Output saved to output.jsonl.\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Run the moondream captioning model on a directory of images.\")\n    parser.add_argument(\"--dir\", type=str, required=True, help=\"Directory containing images.\")\n    parser.add_argument(\n        \"--prompt\",\n        type=str,\n        default=\"Describe this image in 20 words or less.\",\n        help=\"(Optional) Prompt for the model.\",\n    )\n    parser.add_argument(\n        \"--cpu\",\n        action=\"store_true\",\n        default=False,\n        help=\"Force use of CPU instead of GPU. If not set, a GPU will be used if available.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=4,\n        help=\"Batch size for processing images. To maximize speed, set this to the largest value that fits in GPU \"\n        \"memory.\",\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        default=\"output.jsonl\",\n        help=\"(Optional) Path to the output file. 
Default is 'output.jsonl'.\",\n    )\n    args = parser.parse_args()\n\n    # Prepare the dataset.\n    dataset = ImageDirDataset(args.dir)\n    print(f\"Found {len(dataset)} images in '{args.dir}'.\")\n\n    main(args.prompt, args.cpu, args.batch_size, args.output, dataset)\n"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/clipseg.py",
    "content": "import torch\nfrom PIL import Image\nfrom transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor\n\n\ndef load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:\n    # Load the model.\n    clipseg_processor = AutoProcessor.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n    clipseg_model = CLIPSegForImageSegmentation.from_pretrained(\"CIDAS/clipseg-rd64-refined\")\n    return clipseg_processor, clipseg_model\n\n\ndef run_clipseg(\n    images: list[Image.Image],\n    prompt: str,\n    clipseg_processor,\n    clipseg_model,\n    clipseg_temp: float,\n    device: torch.device,\n) -> list[Image.Image]:\n    \"\"\"Run ClipSeg on a list of images.\n\n    Args:\n        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'\n            and include more of the background. Recommended range: 0.5 to 1.0.\n    \"\"\"\n\n    orig_image_sizes = [img.size for img in images]\n\n    prompts = [prompt] * len(images)\n    # TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?\n    inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors=\"pt\")\n\n    # Move inputs and clipseg_model to the correct device and dtype.\n    inputs = {k: v.to(device=device) for k, v in inputs.items()}\n    clipseg_model = clipseg_model.to(device=device)\n\n    outputs = clipseg_model(**inputs)\n\n    logits = outputs.logits\n    if logits.ndim == 2:\n        # The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.\n        logits = logits.unsqueeze(0)\n    probs = torch.nn.functional.sigmoid(logits / clipseg_temp)\n    # Normalize each mask to 0-255. 
Note that each mask is normalized independently.\n    probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)\n\n    # Make mask greyscale.\n    masks: list[Image.Image] = []\n    for prob, orig_size in zip(probs, orig_image_sizes, strict=True):\n        mask = Image.fromarray(prob.cpu().numpy()).convert(\"L\")\n        mask = mask.resize(orig_size)\n        masks.append(mask)\n\n    return masks\n\n\ndef select_device() -> torch.device:\n    if torch.cuda.is_available():\n        return torch.device(\"cuda\")\n\n    return torch.device(\"cpu\")\n"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/generate_masks.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device\nfrom invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn\n\n\n@torch.no_grad()\ndef generate_masks(image_dir: str, prompt: str, clipseg_temp: float, batch_size: int):\n    \"\"\"Generate masks for a directory of images.\n\n    Args:\n        image_dir (str): The directory containing images.\n        prompt (str): A short description of the thing you want to mask. E.g. 'a cat'.\n        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'\n            and include more of the background. Recommended range: 0.5 to 1.0.\n        batch_size (int): Batch size to use when processing images. Larger batch sizes may be faster but require more\n            memory.\n    \"\"\"\n    device = select_device()\n\n    clipseg_processor, clipseg_model = load_clipseg_model()\n\n    # Prepare the dataloader.\n    dataset = ImageDirDataset(image_dir)\n    print(f\"Found {len(dataset)} images in '{image_dir}'.\")\n    data_loader = torch.utils.data.DataLoader(\n        dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False\n    )\n\n    # Process each image.\n    for batch in tqdm(data_loader):\n        masks = run_clipseg(\n            images=batch[\"image\"],\n            prompt=prompt,\n            clipseg_processor=clipseg_processor,\n            clipseg_model=clipseg_model,\n            clipseg_temp=clipseg_temp,\n            device=device,\n        )\n\n        for image_path, mask in zip(batch[\"image_path\"], masks, strict=True):\n            image_path = Path(image_path)\n            out_path = image_path.parent / \"masks\" / (image_path.stem + \".png\")\n            out_path.parent.mkdir(exist_ok=True, parents=True)\n            mask.save(out_path)\n 
           print(f\"Saved mask to: {out_path}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Generate masks for a directory of images.\")\n    parser.add_argument(\"--dir\", type=str, required=True, help=\"Directory containing images.\")\n    parser.add_argument(\n        \"--prompt\",\n        required=True,\n        type=str,\n        help=\"A short description of the thing you want to mask. E.g. 'a cat'.\",\n    )\n    parser.add_argument(\n        \"--clipseg-temp\",\n        type=float,\n        default=1.0,\n        help=\"Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include \"\n        \"more of the background. Recommended range: 0.5 to 1.0.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=4,\n        help=\"Batch size to use when processing images. Larger batch sizes may be faster but require more memory.\",\n    )\n    args = parser.parse_args()\n\n    generate_masks(image_dir=args.dir, prompt=args.prompt, clipseg_temp=args.clipseg_temp, batch_size=args.batch_size)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_training._shared.data.datasets.image_caption_jsonl_dataset import (\n    MASK_COLUMN_DEFAULT,\n    ImageCaptionJsonlDataset,\n)\nfrom invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl\nfrom invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device\n\n\ndef collate_fn(examples):\n    \"\"\"A collate_fn that combines images into a list rather than stacking into a tensor.\"\"\"\n    return {\n        \"id\": [example[\"id\"] for example in examples],\n        \"image\": [example[\"image\"] for example in examples],\n    }\n\n\ndef validate_out_json_path(out_json_path: str | Path):\n    out_json_path = Path(out_json_path)\n    if out_json_path.exists():\n        raise FileExistsError(f\"Output jsonl file '{out_json_path}' already exists.\")\n    if not out_json_path.suffix == \".jsonl\":\n        raise ValueError(f\"Output jsonl file '{out_json_path}' must have a .jsonl extension.\")\n\n\n@torch.no_grad()\ndef generate_masks(\n    in_jsonl_path: str,\n    out_jsonl_path: str,\n    image_column: str,\n    caption_column: str,\n    prompt: str,\n    clipseg_temp: float,\n    batch_size: int,\n):\n    \"\"\"Generate masks for a .jsonl dataset.\"\"\"\n    # Load the .jsonl dataset.\n    dataset = ImageCaptionJsonlDataset(\n        jsonl_path=in_jsonl_path, image_column=image_column, caption_column=caption_column\n    )\n    print(f\"Loaded dataset from '{in_jsonl_path}' with {len(dataset)} images.\")\n    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size, drop_last=False)\n\n    # We also need the raw jsonl data.\n    jsonl_data = load_jsonl(in_jsonl_path)\n\n    # Prepare output locations.\n    out_jsonl_path = Path(out_jsonl_path)\n    validate_out_json_path(out_jsonl_path)\n    out_masks_dir = out_jsonl_path.parent / 
\"masks\"\n    out_masks_dir.mkdir(exist_ok=False, parents=True)\n\n    clipseg_processor, clipseg_model = load_clipseg_model()\n\n    device = select_device()\n\n    # Process each image.\n    for batch in tqdm(data_loader):\n        masks = run_clipseg(\n            images=batch[\"image\"],\n            prompt=prompt,\n            clipseg_processor=clipseg_processor,\n            clipseg_model=clipseg_model,\n            clipseg_temp=clipseg_temp,\n            device=device,\n        )\n\n        for id, mask in zip(batch[\"id\"], masks, strict=True):\n            orig_image_path = Path(jsonl_data[int(id)][image_column])\n            out_mask_path: Path = out_masks_dir / (orig_image_path.stem + \".png\")\n            mask.save(out_mask_path)\n            print(f\"Saved mask to: {out_mask_path}\")\n\n            # Infer whether the mask path should be relative or absolute based on the image path.\n            if orig_image_path.is_absolute():\n                jsonl_data[int(id)][MASK_COLUMN_DEFAULT] = str(out_mask_path.resolve())\n            else:\n                jsonl_data[int(id)][MASK_COLUMN_DEFAULT] = str(out_mask_path.relative_to(out_jsonl_path.parent))\n\n    # Save the modified jsonl data.\n    validate_out_json_path(out_jsonl_path)\n    save_jsonl(jsonl_data, out_jsonl_path)\n    print(f\"Saved modified jsonl data to: {out_jsonl_path}\")\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Generate masks for a jsonl dataset.\")\n    parser.add_argument(\"--in-jsonl\", type=str, required=True, help=\"Path to the dataset .jsonl file.\")\n    parser.add_argument(\n        \"--out-jsonl\",\n        type=str,\n        required=True,\n        help=\"Path to save the modified .jsonl file to. A masks/ directory will be created in the same directory as \"\n        \"the .jsonl file to store the masks. 
The choice of whether to use relative or absolute paths for the masks is \"\n        \"inferred from the image paths.\",\n    )\n    parser.add_argument(\n        \"--image-column\",\n        type=str,\n        default=\"image\",\n        help=\"The name of the column containing image paths in the input .jsonl file.\",\n    )\n    parser.add_argument(\n        \"--caption-column\",\n        type=str,\n        default=\"text\",\n        help=\"The name of the column containing captions in the input .jsonl file.\",\n    )\n    parser.add_argument(\n        \"--prompt\",\n        required=True,\n        type=str,\n        help=\"A short description of the thing you want to mask. E.g. 'a cat'.\",\n    )\n    parser.add_argument(\n        \"--clipseg-temp\",\n        type=float,\n        default=1.0,\n        help=\"Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include \"\n        \"more of the background. Recommended range: 0.5 to 1.0.\",\n    )\n    parser.add_argument(\n        \"--batch-size\",\n        type=int,\n        default=4,\n        help=\"Batch size to use when processing images. Larger batch sizes may be faster but require more memory.\",\n    )\n    args = parser.parse_args()\n\n    generate_masks(\n        in_jsonl_path=args.in_jsonl,\n        out_jsonl_path=args.out_jsonl,\n        image_column=args.image_column,\n        caption_column=args.caption_column,\n        prompt=args.prompt,\n        clipseg_temp=args.clipseg_temp,\n        batch_size=args.batch_size,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/rank_images.py",
    "content": "import argparse\nimport os\nimport time\nfrom pathlib import Path\nfrom typing import Literal\n\nimport gradio as gr\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset\nfrom invoke_training.config.pipeline_config import PipelineConfig\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Choose preferences from image pairs.\")\n    parser.add_argument(\n        \"-c\",\n        \"--cfg-file\",\n        type=Path,\n        required=True,\n        help=\"Path to the YAML training config file. The internal dataset config will be used.\",\n    )\n\n    return parser.parse_args()\n\n\ndef clip(val, min_val, max_val):\n    return max(min(val, max_val), min_val)\n\n\ndef main():\n    args = parse_args()\n\n    # Load YAML config file.\n    with open(args.cfg_file, \"r\") as f:\n        cfg = yaml.safe_load(f)\n\n    pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)\n    train_config = pipeline_adapter.validate_python(cfg)\n    dataset_config = train_config.data_loader.dataset\n    assert dataset_config.type == \"IMAGE_PAIR_PREFERENCE_DATASET\"\n    metadata = ImagePairPreferenceDataset.load_metadata(dataset_config.dataset_dir)\n\n    print(f\"Launching UI to rank image pairs in '{dataset_config.dataset_dir}'.\")\n\n    def get_img_path(index: int, image_id: Literal[\"image_0\", \"image_1\"]):\n        return os.path.join(dataset_config.dataset_dir, metadata[index][image_id])\n\n    def get_state(index: int):\n        img_0 = get_img_path(index, \"image_0\")\n        img_1 = get_img_path(index, \"image_1\")\n        prefer_0 = metadata[index][\"prefer_0\"]\n        prefer_1 = metadata[index][\"prefer_1\"]\n        caption = metadata[index][\"prompt\"]\n        return [index, img_0, img_1, prefer_0, prefer_1, caption]\n\n    def go_to_index(index: int):\n        new_index = clip(index, 0, len(metadata) - 
1)\n        return get_state(new_index)\n\n    def mark_prefer_0(index: int):\n        metadata[index][\"prefer_0\"] = True\n        metadata[index][\"prefer_1\"] = False\n        # Step to next example.\n        return go_to_index(index + 1)\n\n    def mark_prefer_1(index: int):\n        metadata[index][\"prefer_0\"] = False\n        metadata[index][\"prefer_1\"] = True\n        # Step to next example.\n        return go_to_index(index + 1)\n\n    def save_metadata():\n        timestamp = str(time.time()).replace(\".\", \"_\")\n        metadata_file = f\"metadata-{timestamp}.jsonl\"\n        metadata_path = ImagePairPreferenceDataset.save_metadata(\n            metadata=metadata, dataset_dir=dataset_config.dataset_dir, metadata_file=metadata_file\n        )\n        print(f\"Saved metadata to '{metadata_path}'.\")\n\n    with gr.Blocks() as demo:\n        index = gr.Number(value=-1, precision=0)\n        with gr.Row():\n            img_0 = gr.Image(type=\"filepath\", label=\"Image 0\", interactive=False)\n            img_1 = gr.Image(type=\"filepath\", label=\"Image 1\", interactive=False)\n\n        caption = gr.Textbox(interactive=False, show_label=False)\n\n        with gr.Row():\n            prefer_0 = gr.Checkbox(label=\"Prefer 0\", interactive=False)\n            prefer_1 = gr.Checkbox(label=\"Prefer 1\", interactive=False)\n\n        with gr.Row():\n            mark_prefer_0_button = gr.Button(\"Prefer 0\")\n            mark_prefer_1_button = gr.Button(\"Prefer 1\")\n\n        save_metadata_button = gr.Button(\"Save Metadata\")\n\n        index.change(go_to_index, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption])\n        mark_prefer_0_button.click(\n            mark_prefer_0, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption]\n        )\n        mark_prefer_1_button.click(\n            mark_prefer_1, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption]\n        )\n        
save_metadata_button.click(save_metadata)\n\n    demo.launch()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport torch\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    convert_sd_peft_checkpoint_to_kohya_state_dict,\n)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Convert a Stable Diffusion LoRA checkpoint in PEFT format to kohya format.\"\n    )\n    parser.add_argument(\n        \"--src-ckpt-dir\",\n        type=str,\n        required=True,\n        help=\"Path to the source checkpoint directory.\",\n    )\n    parser.add_argument(\n        \"--dst-ckpt-file\",\n        type=str,\n        required=True,\n        help=\"Path to the destination Kohya checkpoint file.\",\n    )\n    parser.add_argument(\n        \"--dtype\",\n        type=str,\n        default=\"fp16\",\n        help=\"The precision to save the kohya state dict in. One of ['fp16', 'fp32'].\",\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    in_checkpoint_dir = Path(args.src_ckpt_dir)\n    out_checkpoint_file = Path(args.dst_ckpt_file)\n\n    if args.dtype == \"fp32\":\n        dtype = torch.float32\n    elif args.dtype == \"fp16\":\n        dtype = torch.float16\n    else:\n        raise ValueError(f\"Unsupported --dtype = '{args.dtype}'.\")\n\n    convert_sd_peft_checkpoint_to_kohya_state_dict(\n        in_checkpoint_dir=in_checkpoint_dir, out_checkpoint_file=out_checkpoint_file, dtype=dtype\n    )\n\n    print(f\"Saved kohya checkpoint to '{out_checkpoint_file}'.\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/invoke_generate_images.py",
    "content": "import argparse\nfrom pathlib import Path\n\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum\nfrom invoke_training._shared.tools.generate_images import generate_images\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(\n        description=\"Generate a dataset of images from a single prompt. (Typically used to generate prior \"\n        \"preservation/regularization datasets.)\"\n    )\n    parser.add_argument(\n        \"-o\",\n        \"--out-dir\",\n        type=str,\n        required=True,\n        help=\"Path to the directory where the images will be stored.\",\n    )\n    parser.add_argument(\n        \"-m\",\n        \"--model\",\n        type=str,\n        required=True,\n        help=\"Name or path of the diffusers model to generate images with. Can be in diffusers format, or a single \"\n        \"stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', \"\n        \"'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )\",\n    )\n    parser.add_argument(\n        \"-v\",\n        \"--variant\",\n        type=str,\n        required=False,\n        default=None,\n        help=\"The Hugging Face Hub model variant to use. Only applies if `--model` is a Hugging Face Hub model name.\",\n    )\n    parser.add_argument(\n        \"-l\",\n        \"--lora\",\n        type=str,\n        nargs=\"*\",\n        help=\"LoRA models to apply to the base model. The LoRA weight can optionally be provided after a colon \"\n        \"separator. E.g. `-l path/to/lora.bin:0.5 -l path/to/lora_2.safetensors`. \",\n    )\n    parser.add_argument(\n        \"--ti\",\n        type=str,\n        nargs=\"*\",\n        help=\"Paths(s) to Textual Inversion embeddings to apply to the base model.\",\n    )\n    parser.add_argument(\n        \"--sd-version\",\n        type=str,\n        required=True,\n        help=\"The Stable Diffusion version. 
One of: ['SD', 'SDXL'].\",\n    )\n\n    # One of --prompt or --prompt-file.\n    group = parser.add_mutually_exclusive_group(required=True)\n    group.add_argument(\"-p\", \"--prompt\", type=str, help=\"The prompt to use for image generation.\")\n    group.add_argument(\"--prompt-file\", type=str, help=\"A file containing prompts. One per line.\")\n\n    parser.add_argument(\n        \"--set-size\", type=int, default=1, help=\"The number of images generated in each 'set' for a given prompt.\"\n    )\n    parser.add_argument(\"--num-sets\", type=int, default=1, help=\"The number of 'sets' to generate for each prompt.\")\n    parser.add_argument(\n        \"--height\",\n        type=int,\n        required=True,\n        help=\"The height of the generated images in pixels.\",\n    )\n    parser.add_argument(\n        \"--width\",\n        type=int,\n        required=True,\n        help=\"The width of the generated images in pixels.\",\n    )\n    parser.add_argument(\n        \"-s\",\n        \"--seed\",\n        type=int,\n        default=0,\n        help=\"Seed for repeatability.\",\n    )\n    parser.add_argument(\n        \"--enable-cpu-offload\",\n        default=False,\n        action=\"store_true\",\n        help=\"If True, models will be loaded onto the GPU one by one to conserve VRAM.\",\n    )\n    return parser.parse_args()\n\n\ndef parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, int]]:\n    loras: list[tuple[Path, int]] = []\n\n    lora_args = lora_args or []\n    for lora in lora_args:\n        lora_split = lora.split(\":\")\n\n        if len(lora_split) == 1:\n            # If weight is not specified, assume 1.0.\n            loras.append((Path(lora_split[0]), 1.0))\n        elif len(lora_split) == 2:\n            loras.append((Path(lora_split[0]), float(lora_split[1])))\n        else:\n            raise ValueError(f\"Invalid lora argument syntax: '{lora}'.\")\n\n    return loras\n\n\ndef parse_prompt_file(prompt_file: str) -> 
list[str]:\n    with open(prompt_file) as f:\n        prompts = f.readlines()\n\n    return [p.strip() for p in prompts]\n\n\ndef main():\n    args = parse_args()\n\n    loras = parse_lora_args(args.lora)\n\n    if args.prompt:\n        prompts = [args.prompt]\n    else:\n        prompts = parse_prompt_file(args.prompt_file)\n\n    print(f\"Generating {args.num_sets} sets of {args.set_size} images for {len(prompts)} prompts in '{args.out_dir}'.\")\n    generate_images(\n        out_dir=args.out_dir,\n        model=args.model,\n        hf_variant=args.variant,\n        pipeline_version=PipelineVersionEnum(args.sd_version),\n        prompts=prompts,\n        set_size=args.set_size,\n        num_sets=args.num_sets,\n        height=args.height,\n        width=args.width,\n        loras=loras,\n        ti_embeddings=args.ti,\n        seed=args.seed,\n        enable_cpu_offload=args.enable_cpu_offload,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/invoke_train.py",
    "content": "import argparse\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.pipelines.invoke_train import train\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Run a training pipeline.\")\n    parser.add_argument(\n        \"-c\",\n        \"--cfg-file\",\n        type=Path,\n        required=True,\n        help=\"Path to the YAML training config file.\",\n    )\n    return parser.parse_args()\n\n\ndef main():\n    args = parse_args()\n\n    # Load YAML config file.\n    with open(args.cfg_file, \"r\") as f:\n        cfg = yaml.safe_load(f)\n\n    pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)\n    train_config = pipeline_adapter.validate_python(cfg)\n\n    train(train_config)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/invoke_train_ui.py",
    "content": "import argparse\n\nimport uvicorn\n\nfrom invoke_training.ui.app import build_app\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--host\",\n        default=\"127.0.0.1\",\n        help=\"The server host. Set `--host 0.0.0.0` to make the app available on your network.\",\n    )\n    parser.add_argument(\"--port\", default=8000, type=int, help=\"The server port.\")\n    args = parser.parse_args()\n\n    app = build_app()\n    uvicorn.run(\n        app,\n        host=args.host,\n        port=args.port,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/invoke_visualize_data_loading.py",
    "content": "import argparse\nimport os\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport yaml\nfrom PIL import Image\nfrom pydantic import TypeAdapter\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (\n    build_image_caption_sd_dataloader,\n)\nfrom invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import (\n    build_image_pair_preference_sd_dataloader,\n)\nfrom invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (\n    build_textual_inversion_sd_dataloader,\n)\nfrom invoke_training.config.pipeline_config import PipelineConfig\n\n\ndef save_image(torch_image: torch.Tensor, out_path: Path):\n    \"\"\"Save a torch image to disk.\n\n    Args:\n        torch_image (torch.Tensor): Shape=(C, H, W). Pixel values are expected to be normalized in the range\n            [-1.0, 1.0].\n        out_path (Path): The output path.\n    \"\"\"\n    np_image = torch_image.clone().detach().cpu().numpy()\n\n    # Convert back to range [0, 1.0].\n    np_image = np_image * 0.5 + 0.5\n    # Convert back to range [0, 255].\n    np_image *= 255\n    # Move channel axis from first dimension to last dimension.\n    np_image = np.moveaxis(np_image, 0, -1)\n\n    # Cast to np.uint8.\n    np_image = np_image.astype(np.uint8)\n\n    Image.fromarray(np_image).save(out_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description=\"Visualize data loading from a pipeline config.\")\n    parser.add_argument(\n        \"-c\",\n        \"--cfg-file\",\n        type=Path,\n        required=True,\n        help=\"Path to the YAML training config file.\",\n    )\n\n    return parser.parse_args()\n\n\ndef visualize(data_loader: DataLoader):\n    out_dir = Path(f\"out_{str(time.time()).replace('.', '-')}/\")\n  
  os.makedirs(out_dir)\n\n    for batch_idx, batch in enumerate(data_loader):\n        print(f\"Batch {batch_idx}:\")\n        batch_path = out_dir / f\"batch_{batch_idx}\"\n        batch_path.mkdir()\n        saved_images = []\n        for k, v in batch.items():\n            if isinstance(v, torch.Tensor):\n                print(f\"{k}: Tensor.shape={v.shape}\")\n                if len(v.shape) == 4 and v.shape[1] == 3:\n                    # This is likely a batch of RGB images, so we save them to disk.\n                    for i in range(v.shape[0]):\n                        out_path = batch_path / f\"{k}_{i}.png\"\n                        save_image(v[i, ...], out_path)\n                        saved_images.append(out_path)\n            else:\n                print(f\"{k}: {v}\")\n\n        for saved_image in saved_images:\n            print(f\"Saved image to '{saved_image}'.\")\n\n        _ = input(\"\\n\\nPress Enter to continue to next batch...\\n\")\n\n\ndef main():\n    args = parse_args()\n\n    # Load YAML config file.\n    with open(args.cfg_file, \"r\") as f:\n        cfg = yaml.safe_load(f)\n\n    pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)\n    train_config = pipeline_adapter.validate_python(cfg)\n\n    data_loader_config = train_config.data_loader\n\n    if data_loader_config.type == \"IMAGE_CAPTION_SD_DATA_LOADER\":\n        data_loader = build_image_caption_sd_dataloader(\n            config=data_loader_config,\n            batch_size=train_config.train_batch_size,\n            shuffle=False,\n        )\n    elif data_loader_config.type == \"TEXTUAL_INVERSION_SD_DATA_LOADER\":\n        data_loader = build_textual_inversion_sd_dataloader(\n            config=data_loader_config,\n            placeholder_token=\"<placeholder_token_not_shown_when_visualizing>\",\n            batch_size=train_config.train_batch_size,\n            shuffle=False,\n        )\n    elif data_loader_config.type == 
\"DREAMBOOTH_SD_DATA_LOADER\":\n        data_loader = build_dreambooth_sd_dataloader(\n            config=data_loader_config,\n            batch_size=train_config.train_batch_size,\n            shuffle=False,\n            sequential_batching=False,\n        )\n    elif data_loader_config.type == \"IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER\":\n        data_loader = build_image_pair_preference_sd_dataloader(\n            config=data_loader_config,\n            batch_size=train_config.train_batch_size,\n            shuffle=False,\n        )\n    else:\n        raise ValueError(f\"Unexpected data loader type: '{data_loader_config.type}'.\")\n\n    visualize(data_loader)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "src/invoke_training/scripts/utils/image_dir_dataset.py",
    "content": "import os\nimport typing\n\nimport torch\nfrom PIL import Image\n\n\nclass ImageDirDataset(torch.utils.data.Dataset):\n    \"\"\"A simple dataset that loads images from a directory.\"\"\"\n\n    def __init__(\n        self,\n        dataset_dir: str,\n        image_extensions: typing.Optional[list[str]] = None,\n    ):\n        super().__init__()\n        if image_extensions is None:\n            image_extensions = [\".png\", \".jpg\", \".jpeg\"]\n        image_extensions = [ext.lower() for ext in image_extensions]\n\n        # Determine the list of image paths to include in the dataset.\n        self._image_paths: list[str] = []\n        for image_file in os.listdir(dataset_dir):\n            image_path = os.path.join(dataset_dir, image_file)\n            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:\n                self._image_paths.append(image_path)\n        self._image_paths.sort()\n\n    def _load_image(self, image_path: str) -> Image.Image:\n        # We call `convert(\"RGB\")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale\n        # images.\n        return Image.open(image_path).convert(\"RGB\")\n\n    def __len__(self) -> int:\n        return len(self._image_paths)\n\n    def __getitem__(self, idx: int):\n        image_path = self._image_paths[idx]\n        image = self._load_image(image_path)\n        return {\"image_path\": self._image_paths[idx], \"image\": image}\n\n\ndef list_collate_fn(examples):\n    \"\"\"Custom collate_fn that combines images into a list rather than stacking into a tensor. This is what the Moondream\n    model expects.\n    \"\"\"\n    return {\n        \"image\": [example[\"image\"] for example in examples],\n        \"image_path\": [example[\"image_path\"] for example in examples],\n    }\n"
  },
  {
    "path": "src/invoke_training/ui/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/ui/app.py",
    "content": "from pathlib import Path\n\nimport gradio as gr\nfrom fastapi import FastAPI\nfrom fastapi.responses import FileResponse\nfrom fastapi.staticfiles import StaticFiles\n\nfrom invoke_training.ui.pages.data_page import DataPage\nfrom invoke_training.ui.pages.training_page import TrainingPage\n\n\ndef build_app():\n    training_page = TrainingPage()\n    data_page = DataPage()\n\n    app = FastAPI()\n\n    @app.get(\"/\")\n    async def root():\n        index_path = Path(__file__).parent / \"index.html\"\n        return FileResponse(index_path)\n\n    app.mount(\"/assets\", StaticFiles(directory=Path(__file__).parent.parent / \"assets\"), name=\"assets\")\n\n    app = gr.mount_gradio_app(app, training_page.app(), \"/train\", app_kwargs={\"favicon_path\": \"/assets/favicon.png\"})\n    app = gr.mount_gradio_app(app, data_page.app(), \"/data\", app_kwargs={\"favicon_path\": \"/assets/favicon.png\"})\n    return app\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/__init__.py",
    "content": ""
  },
  {
    "path": "src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import AspectRatioBucketConfig\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\n\nclass AspectRatioBucketConfigGroup(UIConfigElement):\n    def __init__(self):\n        gr.Markdown(\n            \"Aspect ratio bucket resolutions are generated as follows:\\n\"\n            \"- Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`.\\n\"\n            \"- Calculate the 'second' dimension to be as close as possible to the total number of pixels in \"\n            \"`target_resolution`, while still being divisible by `divisible_by`.\"\n        )\n        self.enabled = gr.Checkbox(label=\"Use Aspect Ratio Bucketing\", interactive=True)\n        self.target_resolution = gr.Number(label=\"target_resolution\", interactive=True, precision=0)\n        self.start_dim = gr.Number(label=\"start_dimension\", interactive=True, precision=0)\n        self.end_dim = gr.Number(label=\"end_dimension\", interactive=True, precision=0)\n        self.divisible_by = gr.Number(label=\"divisible_by\", interactive=True, precision=0)\n\n    def update_ui_components_with_config_data(\n        self, config: AspectRatioBucketConfig | None\n    ) -> dict[gr.components.Component, Any]:\n        enabled = True\n        if config is None:\n            enabled = False\n            # We just construct this config to hold default values.\n            config = AspectRatioBucketConfig(target_resolution=512, start_dim=256, end_dim=768, divisible_by=64)\n\n        update_dict = {\n            self.enabled: enabled,\n            self.target_resolution: config.target_resolution,\n            self.start_dim: config.start_dim,\n            self.end_dim: config.end_dim,\n            self.divisible_by: config.divisible_by,\n        }\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        
self, orig_config: AspectRatioBucketConfig | None, ui_data: dict[gr.components.Component, Any]\n    ) -> AspectRatioBucketConfig | None:\n        # TODO: Use orig_config?\n        if not ui_data.pop(self.enabled):\n            # Pop fields from ui_data so that upstream code knows that the fields were handled.\n            ui_data.pop(self.target_resolution)\n            ui_data.pop(self.start_dim)\n            ui_data.pop(self.end_dim)\n            ui_data.pop(self.divisible_by)\n            return None\n\n        new_config = AspectRatioBucketConfig(\n            target_resolution=ui_data.pop(self.target_resolution),\n            start_dim=ui_data.pop(self.start_dim),\n            end_dim=ui_data.pop(self.end_dim),\n            divisible_by=ui_data.pop(self.divisible_by),\n        )\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/base_pipeline_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\n\nclass BasePipelineConfigGroup(UIConfigElement):\n    def __init__(self):\n        self.base_output_dir = gr.Textbox(\n            label=\"Base Output Directory\",\n            info=\"The base output directory where the training outputs (model checkpoints, logs,\"\n            \" intermediate predictions) will be written.\",\n            interactive=True,\n        )\n        with gr.Row():\n            with gr.Column():\n                self.max_train_steps_or_epochs_dropdown = gr.Dropdown(\n                    label=\"Training Length\",\n                    info=\"Train for a fixed number of gradient update steps or epochs.\",\n                    choices=[\"max_train_steps\", \"max_train_epochs\"],\n                    interactive=True,\n                )\n                self.max_train_steps_or_epochs = gr.Number(label=\"Steps or Epochs\", precision=0, interactive=True)\n\n            with gr.Column():\n                self.save_every_n_steps_or_epochs_dropdown = gr.Dropdown(\n                    label=\"Checkpoint Save Frequency\",\n                    info=\"Save a checkpoint every N gradient update steps or epochs.\",\n                    choices=[\"save_every_n_steps\", \"save_every_n_epochs\"],\n                    interactive=True,\n                )\n                self.save_every_n_steps_or_epochs = gr.Number(label=\"Steps or Epochs\", precision=0, interactive=True)\n\n            with gr.Column():\n                self.validate_every_n_steps_or_epochs_dropdown = gr.Dropdown(\n                    label=\"Validation Frequency\",\n                    info=\"Save validation images every N gradient update steps or epochs.\",\n                    
choices=[\"validate_every_n_steps\", \"validate_every_n_epochs\"],\n                    interactive=True,\n                )\n                self.validate_every_n_steps_or_epochs = gr.Number(\n                    label=\"Steps or Epochs\", precision=0, interactive=True\n                )\n        self.seed = gr.Number(\n            label=\"Seed\",\n            info=\"Set to any constant integer for consistent training results. If set to null, training\"\n            \" will be non-deterministic.\",\n            precision=0,\n            interactive=True,\n        )\n\n    def update_ui_components_with_config_data(self, config: BasePipelineConfig) -> dict[gr.components.Component, Any]:\n        if config.max_train_epochs is not None:\n            max_train_steps_or_epochs_dropdown = \"max_train_epochs\"\n            max_train_steps_or_epochs = config.max_train_epochs\n        elif config.max_train_steps is not None:\n            max_train_steps_or_epochs_dropdown = \"max_train_steps\"\n            max_train_steps_or_epochs = config.max_train_steps\n        else:\n            raise ValueError(\"One of max_train_epochs or max_train_steps must be set.\")\n\n        if config.save_every_n_epochs is not None:\n            save_every_n_steps_or_epochs_dropdown = \"save_every_n_epochs\"\n            save_every_n_steps_or_epochs = config.save_every_n_epochs\n        elif config.save_every_n_steps is not None:\n            save_every_n_steps_or_epochs_dropdown = \"save_every_n_steps\"\n            save_every_n_steps_or_epochs = config.save_every_n_steps\n        else:\n            raise ValueError(\"One of save_every_n_epochs or save_every_n_steps must be set.\")\n\n        if config.validate_every_n_epochs is not None:\n            validate_every_n_steps_or_epochs_dropdown = \"validate_every_n_epochs\"\n            validate_every_n_steps_or_epochs = config.validate_every_n_epochs\n        elif config.validate_every_n_steps is not None:\n            
validate_every_n_steps_or_epochs_dropdown = \"validate_every_n_steps\"\n            validate_every_n_steps_or_epochs = config.validate_every_n_steps\n        else:\n            raise ValueError(\"One of validate_every_n_epochs or validate_every_n_steps must be set.\")\n\n        return {\n            self.seed: config.seed,\n            self.base_output_dir: config.base_output_dir,\n            self.max_train_steps_or_epochs_dropdown: max_train_steps_or_epochs_dropdown,\n            self.max_train_steps_or_epochs: max_train_steps_or_epochs,\n            self.save_every_n_steps_or_epochs_dropdown: save_every_n_steps_or_epochs_dropdown,\n            self.save_every_n_steps_or_epochs: save_every_n_steps_or_epochs,\n            self.validate_every_n_steps_or_epochs_dropdown: validate_every_n_steps_or_epochs_dropdown,\n            self.validate_every_n_steps_or_epochs: validate_every_n_steps_or_epochs,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: PipelineConfig, ui_data: dict[gr.components.Component, Any]\n    ) -> PipelineConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.seed = ui_data.pop(self.seed)\n        new_config.base_output_dir = ui_data.pop(self.base_output_dir)\n\n        if ui_data.pop(self.max_train_steps_or_epochs_dropdown) == \"max_train_epochs\":\n            new_config.max_train_epochs = ui_data.pop(self.max_train_steps_or_epochs)\n            new_config.max_train_steps = None\n        else:\n            new_config.max_train_steps = ui_data.pop(self.max_train_steps_or_epochs)\n            new_config.max_train_epochs = None\n\n        if ui_data.pop(self.save_every_n_steps_or_epochs_dropdown) == \"save_every_n_epochs\":\n            new_config.save_every_n_epochs = ui_data.pop(self.save_every_n_steps_or_epochs)\n            new_config.save_every_n_steps = None\n        else:\n            new_config.save_every_n_steps = ui_data.pop(self.save_every_n_steps_or_epochs)\n           
 new_config.save_every_n_epochs = None\n\n        if ui_data.pop(self.validate_every_n_steps_or_epochs_dropdown) == \"validate_every_n_epochs\":\n            new_config.validate_every_n_epochs = ui_data.pop(self.validate_every_n_steps_or_epochs)\n            new_config.validate_every_n_steps = None\n        else:\n            new_config.validate_every_n_steps = ui_data.pop(self.validate_every_n_steps_or_epochs)\n            new_config.validate_every_n_epochs = None\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/dataset_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCaptionDatasetConfig,\n    ImageCaptionDatasetConfig,\n    ImageCaptionDirDatasetConfig,\n    ImageCaptionJsonlDatasetConfig,\n    ImageDirDatasetConfig,\n)\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\nALL_DATASET_TYPES = [\n    \"HF_HUB_IMAGE_CAPTION_DATASET\",\n    \"IMAGE_CAPTION_JSONL_DATASET\",\n    \"IMAGE_CAPTION_DIR_DATASET\",\n    \"IMAGE_DIR_DATASET\",\n]\n\n\nclass HFHubImageCaptionDatasetConfigGroup(UIConfigElement):\n    def __init__(self):\n        self.dataset_name = gr.Textbox(\n            label=\"Dataset Name\", info=\"Hugging Face Dataset Name (e.g., owner/RepoID).\", interactive=True\n        )\n        with gr.Row():\n            self.dataset_config_name = gr.Textbox(\n                label=\"Dataset Config Name (Optional)\",\n                info=\"The Hugging Face dataset config name. Leave as None if there's only one config.\",\n                interactive=True,\n            )\n        with gr.Row():\n            self.hf_cache_dir = gr.Textbox(\n                label=\"Cache Directory\",\n                info=\"The Hugging Face cache directory to use for dataset downloads. 
If None, the default value\"\n                \" will be used (usually '~/.cache/huggingface/datasets').\",\n                interactive=True,\n            )\n        # self.image_column = gr.Textbox(label=\"image_column\", interactive=True)\n        # self.caption_column = gr.Textbox(label=\"caption_column\", interactive=True)\n\n    def update_ui_components_with_config_data(\n        self, config: HFHubImageCaptionDatasetConfig | None\n    ) -> dict[gr.components.Component, Any]:\n        return {\n            self.dataset_name: config.dataset_name if config else \"<insert_dataset_name>\",\n            self.dataset_config_name: config.dataset_config_name if config else None,\n            self.hf_cache_dir: config.hf_cache_dir if config else None,\n            # self.image_column: config.image_column,\n            # self.caption_column: config.caption_column,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: HFHubImageCaptionDatasetConfig | None, ui_data: dict[gr.components.Component, Any]\n    ) -> HFHubImageCaptionDatasetConfig:\n        assert orig_config is None\n        # new_config = orig_config.model_copy(deep=True)\n\n        new_config = HFHubImageCaptionDatasetConfig(\n            dataset_name=ui_data.pop(self.dataset_name),\n            dataset_config_name=ui_data.pop(self.dataset_config_name) or None,\n            hf_cache_dir=ui_data.pop(self.hf_cache_dir) or None,\n            # image_column=ui_data.pop(self.image_column),\n            # caption_column=ui_data.pop(self.caption_column),\n        )\n        return new_config\n\n\nclass ImageCaptionJsonlDatasetConfigGroup(UIConfigElement):\n    def __init__(self):\n        self.jsonl_path = gr.Textbox(label=\"jsonl_path\", info=\"Path to the dataset `.jsonl` file.\", interactive=True)\n        self.image_column = gr.Textbox(\n            label=\"image_column\",\n            info=\"The name of the field in the `.jsonl` containing image file paths.\",\n            
interactive=True,\n        )\n        self.caption_column = gr.Textbox(\n            label=\"caption_column\",\n            info=\"The name of the field in the `.jsonl` containing image captions.\",\n            interactive=True,\n        )\n        self.keep_in_memory = gr.Checkbox(\n            label=\"keep_in_memory\",\n            info=\"If True, the entire dataset will be kept in RAM. This increases speed for small datasets at the \"\n            \"cost of higher RAM usage.\",\n            interactive=True,\n        )\n\n    def update_ui_components_with_config_data(\n        self, config: ImageCaptionJsonlDatasetConfig | None\n    ) -> dict[gr.components.Component, Any]:\n        if config is None:\n            # We just construct this so that we can use its default values.\n            config = ImageCaptionJsonlDatasetConfig(jsonl_path=\"<path/to/data.jsonl>\")\n\n        return {\n            self.jsonl_path: config.jsonl_path,\n            self.image_column: config.image_column,\n            self.caption_column: config.caption_column,\n            self.keep_in_memory: config.keep_in_memory,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: ImageCaptionJsonlDatasetConfig | None, ui_data: dict[gr.components.Component, Any]\n    ) -> ImageCaptionJsonlDatasetConfig:\n        assert orig_config is None\n        # new_config = orig_config.model_copy(deep=True)\n\n        new_config = ImageCaptionJsonlDatasetConfig(\n            jsonl_path=ui_data.pop(self.jsonl_path),\n            image_column=ui_data.pop(self.image_column),\n            caption_column=ui_data.pop(self.caption_column),\n            keep_in_memory=ui_data.pop(self.keep_in_memory),\n        )\n        return new_config\n\n\nclass ImageCaptionDirDatasetConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Row():\n            self.dataset_dir = gr.Textbox(\n                label=\"dataset_dir\", info=\"The path to the dataset directory.\", 
interactive=True\n            )\n        with gr.Row():\n            self.keep_in_memory = gr.Checkbox(\n                label=\"keep_in_memory\",\n                info=\"If True, the entire dataset will be kept in RAM. This increases speed for small datasets at the \"\n                \"cost of higher RAM usage.\",\n                interactive=True,\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: ImageCaptionDirDatasetConfig | None\n    ) -> dict[gr.components.Component, Any]:\n        return {\n            self.dataset_dir: config.dataset_dir if config else \"<path/to/dataset_dir>\",\n            self.keep_in_memory: config.keep_in_memory if config else False,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: ImageCaptionDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any]\n    ) -> ImageCaptionDirDatasetConfig:\n        assert orig_config is None\n        # new_config = orig_config.model_copy(deep=True)\n\n        new_config = ImageCaptionDirDatasetConfig(\n            dataset_dir=ui_data.pop(self.dataset_dir), keep_in_memory=ui_data.pop(self.keep_in_memory)\n        )\n        return new_config\n\n\nclass ImageDirDatasetConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Row():\n            self.dataset_dir = gr.Textbox(\n                label=\"dataset_dir\", info=\"The path to the dataset directory.\", interactive=True\n            )\n        with gr.Row():\n            self.keep_in_memory = gr.Checkbox(\n                label=\"keep_in_memory\",\n                info=\"If True, the entire dataset will be kept in RAM. 
This increases speed for small datasets at the \"\n                \"cost of higher RAM usage.\",\n                interactive=True,\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: ImageDirDatasetConfig | None\n    ) -> dict[gr.components.Component, Any]:\n        return {\n            self.dataset_dir: config.dataset_dir if config else \"<path/to/dataset_dir>\",\n            self.keep_in_memory: config.keep_in_memory if config else False,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: ImageDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any]\n    ) -> ImageDirDatasetConfig:\n        assert orig_config is None\n        # new_config = orig_config.model_copy(deep=True)\n\n        new_config = ImageDirDatasetConfig(\n            dataset_dir=ui_data.pop(self.dataset_dir), keep_in_memory=ui_data.pop(self.keep_in_memory)\n        )\n        return new_config\n\n\nclass DatasetConfigGroup(UIConfigElement):\n    def __init__(self, allowed_types: list[str]):\n        self.type = gr.Dropdown(\n            choices=[t for t in ALL_DATASET_TYPES if t in allowed_types],\n            label=\"Dataset Type\",\n            info=\"The type of dataset to use for training. 
See \"\n            \"https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/ for a description of each format.\",\n            interactive=True,\n        )\n\n        with gr.Group() as hf_hub_image_caption_dataset_config_group:\n            self.hf_hub_image_caption_dataset_config = HFHubImageCaptionDatasetConfigGroup()\n        self.hf_hub_image_caption_dataset_config_group = hf_hub_image_caption_dataset_config_group\n\n        with gr.Group() as image_caption_jsonl_dataset_config_group:\n            self.image_caption_jsonl_dataset_config = ImageCaptionJsonlDatasetConfigGroup()\n        self.image_caption_jsonl_dataset_config_group = image_caption_jsonl_dataset_config_group\n\n        with gr.Group() as image_caption_dir_dataset_config_group:\n            self.image_caption_dir_dataset_config = ImageCaptionDirDatasetConfigGroup()\n        self.image_caption_dir_dataset_config_group = image_caption_dir_dataset_config_group\n\n        with gr.Group() as image_dir_dataset_config_group:\n            self.image_dir_dataset_config = ImageDirDatasetConfigGroup()\n        self.image_dir_dataset_config_group = image_dir_dataset_config_group\n\n        self.type.change(\n            self._on_type_change,\n            inputs=[self.type],\n            outputs=[\n                self.hf_hub_image_caption_dataset_config_group,\n                self.image_caption_jsonl_dataset_config_group,\n                self.image_caption_dir_dataset_config_group,\n                self.image_dir_dataset_config_group,\n            ],\n        )\n\n    def _on_type_change(self, type: str):\n        return {\n            self.hf_hub_image_caption_dataset_config_group: gr.Group(visible=type == \"HF_HUB_IMAGE_CAPTION_DATASET\"),\n            self.image_caption_jsonl_dataset_config_group: gr.Group(visible=type == \"IMAGE_CAPTION_JSONL_DATASET\"),\n            self.image_caption_dir_dataset_config_group: gr.Group(visible=type == \"IMAGE_CAPTION_DIR_DATASET\"),\n            
self.image_dir_dataset_config_group: gr.Group(visible=type == \"IMAGE_DIR_DATASET\"),\n        }\n\n    def update_ui_components_with_config_data(\n        self, config: ImageCaptionDatasetConfig\n    ) -> dict[gr.components.Component, Any]:\n        update_dict = {\n            self.type: config.type,\n            self.hf_hub_image_caption_dataset_config_group: gr.Group(\n                visible=config.type == \"HF_HUB_IMAGE_CAPTION_DATASET\"\n            ),\n            self.image_caption_jsonl_dataset_config_group: gr.Group(\n                visible=config.type == \"IMAGE_CAPTION_JSONL_DATASET\"\n            ),\n            self.image_caption_dir_dataset_config_group: gr.Group(visible=config.type == \"IMAGE_CAPTION_DIR_DATASET\"),\n            self.image_dir_dataset_config_group: gr.Group(visible=config.type == \"IMAGE_DIR_DATASET\"),\n        }\n\n        update_dict.update(\n            self.hf_hub_image_caption_dataset_config.update_ui_components_with_config_data(\n                config if config.type == \"HF_HUB_IMAGE_CAPTION_DATASET\" else None\n            )\n        )\n        update_dict.update(\n            self.image_caption_jsonl_dataset_config.update_ui_components_with_config_data(\n                config if config.type == \"IMAGE_CAPTION_JSONL_DATASET\" else None\n            )\n        )\n        update_dict.update(\n            self.image_caption_dir_dataset_config.update_ui_components_with_config_data(\n                config if config.type == \"IMAGE_CAPTION_DIR_DATASET\" else None\n            )\n        )\n        update_dict.update(\n            self.image_dir_dataset_config.update_ui_components_with_config_data(\n                config if config.type == \"IMAGE_DIR_DATASET\" else None\n            )\n        )\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: ImageCaptionDatasetConfig, ui_data: dict[gr.components.Component, Any]\n    ) -> ImageCaptionDatasetConfig:\n        # TODO: 
Use orig_config.\n\n        new_config_hf_hub = self.hf_hub_image_caption_dataset_config.update_config_with_ui_component_data(None, ui_data)\n        new_config_jsonl = self.image_caption_jsonl_dataset_config.update_config_with_ui_component_data(None, ui_data)\n        new_config_image_caption_dir = self.image_caption_dir_dataset_config.update_config_with_ui_component_data(\n            None, ui_data\n        )\n        new_config_image_dir = self.image_dir_dataset_config.update_config_with_ui_component_data(None, ui_data)\n\n        type = ui_data.pop(self.type)\n        if type == \"HF_HUB_IMAGE_CAPTION_DATASET\":\n            new_config = new_config_hf_hub\n        elif type == \"IMAGE_CAPTION_JSONL_DATASET\":\n            new_config = new_config_jsonl\n        elif type == \"IMAGE_CAPTION_DIR_DATASET\":\n            new_config = new_config_image_caption_dir\n        elif type == \"IMAGE_DIR_DATASET\":\n            new_config = new_config_image_dir\n        else:\n            raise ValueError(f\"Unknown dataset type: {type}\")\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/flux_lora_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (\n    ImageCaptionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass FluxLoraConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The Flux LoRA configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    # Flux model doesn't use hf_variant\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. 
Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Scheduler Configs\")\n        with gr.Row():\n            with gr.Column():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(FluxLoraConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Learning Rate Warmup Steps\",\n                    info=\"Number of steps for the warmup in the lr scheduler.\",\n                    interactive=True,\n                    precision=0,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.train_transformer = gr.Checkbox(label=\"Train Transformer\", interactive=True)\n            with gr.Row():\n                self.transformer_learning_rate = gr.Number(\n                    label=\"Transformer Learning Rate\",\n                    info=\"The transformer learning rate. 
Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n                    interactive=True,\n                    precision=0,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"Whether to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n                    interactive=True,\n                )\n            # Training/saving/validating steps/epochs are handled by BasePipelineConfigGroup\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.lora_rank_dim = gr.Number(\n                    label=\"LoRA Rank Dim\",\n                    info=\"The rank dimension to use for the LoRA layers. Increasing the rank dimension increases\"\n                    \" the model's expressivity, but also increases the size of the generated LoRA model.\",\n                    interactive=True,\n                    precision=0,\n                )\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minimum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. 
Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Data Type\",\n                    choices=get_typing_literal_options(FluxLoraConfig, \"weight_dtype\"),\n                    info=\"The data type to use for model weights during training.\",\n                    interactive=True,\n                )\n                self.mixed_precision = gr.Dropdown(\n                    label=\"Mixed Precision\",\n                    choices=get_typing_literal_options(FluxLoraConfig, \"mixed_precision\"),\n                    info=\"The mixed precision mode to use.\",\n                    interactive=True,\n                )\n                self.lora_checkpoint_format = gr.Dropdown(\n                    label=\"LoRA Checkpoint Format\",\n                    choices=get_typing_literal_options(FluxLoraConfig, \"lora_checkpoint_format\"),\n                    info=\"The format of the LoRA checkpoint to save.\",\n                    interactive=True,\n                )\n                self.timestep_sampler = gr.Dropdown(\n                    label=\"Timestep Sampler\",\n                    choices=get_typing_literal_options(FluxLoraConfig, \"timestep_sampler\"),\n                    info=\"The timestep sampler to use.\",\n                    interactive=True,\n                )\n                self.discrete_flow_shift = gr.Number(\n                    label=\"Discrete Flow Shift\",\n                    info=\"The shift parameter for the discrete flow. 
Only used if timestep_sampler is 'shift'.\",\n                    interactive=True,\n                )\n                self.sigmoid_scale = gr.Number(\n                    label=\"Sigmoid Scale\",\n                    info=\"The scale parameter for the sigmoid function. Only used if timestep_sampler is 'shift'.\",\n                    interactive=True,\n                )\n                self.lora_scale = gr.Number(\n                    label=\"LoRA Scale\",\n                    info=\"The scale parameter for the LoRA layers.\",\n                    interactive=True,\n                )\n                self.guidance_scale = gr.Number(\n                    label=\"Guidance Scale\",\n                    info=\"The guidance scale for the Flux model.\",\n                    interactive=True,\n                )\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\",\n                    info=\"If True, image masks will be applied to weight the loss during training. The dataset must \"\n                    \"contain masks for this feature to be used.\",\n                    interactive=True,\n                )\n                self.prediction_type = gr.Dropdown(\n                    label=\"Prediction Type\",\n                    choices=[\"epsilon\", \"v_prediction\", None],\n                    info=\"The prediction type that will be used for training.\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def get_ui_output_components(self) -> list[gr.components.Component]:\n        # Get our own components\n        components = [\n            self.model,\n            self.train_transformer,\n            self.transformer_learning_rate,\n            self.gradient_accumulation_steps,\n            self.gradient_checkpointing,\n            self.lr_scheduler,\n            self.lr_warmup_steps,\n            self.lora_rank_dim,\n            self.min_snr_gamma,\n            self.max_grad_norm,\n            self.train_batch_size,\n            self.weight_dtype,\n            self.mixed_precision,\n            self.lora_checkpoint_format,\n            self.timestep_sampler,\n            self.discrete_flow_shift,\n            self.sigmoid_scale,\n            self.lora_scale,\n            self.guidance_scale,\n            self.use_masks,\n            self.prediction_type,\n            # These are not UI components but need to be preserved\n            # self.flux_lora_target_modules,\n            # self.text_encoder_lora_target_modules,\n            self.validation_prompts,\n            self.num_validation_images_per_prompt,\n            self.max_checkpoints,\n        ]\n\n        # Add components from nested config groups\n        components.extend(self.base_pipeline_config_group.get_ui_output_components())\n        components.extend(self.image_caption_sd_data_loader_config_group.get_ui_output_components())\n        components.extend(self.optimizer_config_group.get_ui_output_components())\n\n        return components\n\n    def update_ui_components_with_config_data(\n        self, config: FluxLoraConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        try:\n            update_dict = {\n                
self.model: config.model,\n                self.train_transformer: config.train_transformer,\n                self.transformer_learning_rate: config.transformer_learning_rate,\n                self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n                self.gradient_checkpointing: config.gradient_checkpointing,\n                self.lr_scheduler: config.lr_scheduler,\n                self.lr_warmup_steps: config.lr_warmup_steps,\n                self.lora_rank_dim: config.lora_rank_dim,\n                self.min_snr_gamma: config.min_snr_gamma,\n                self.max_grad_norm: config.max_grad_norm,\n                self.train_batch_size: config.train_batch_size,\n                self.weight_dtype: config.weight_dtype,\n                self.mixed_precision: config.mixed_precision,\n                self.lora_checkpoint_format: config.lora_checkpoint_format,\n                self.timestep_sampler: config.timestep_sampler,\n                self.discrete_flow_shift: config.discrete_flow_shift,\n                self.sigmoid_scale: config.sigmoid_scale,\n                self.lora_scale: config.lora_scale,\n                self.guidance_scale: config.guidance_scale,\n                self.use_masks: config.use_masks,\n                self.prediction_type: config.prediction_type,\n                self.validation_prompts: config.validation_prompts,\n                self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n                self.max_checkpoints: config.max_checkpoints,\n            }\n\n            # Update with nested config groups\n            try:\n                update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n            except Exception as e:\n                print(f\"Error updating base pipeline config: {e}\")\n\n            try:\n                update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n   
         except Exception as e:\n                print(f\"Error updating optimizer config: {e}\")\n\n            try:\n                update_dict.update(\n                    self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(\n                        config.data_loader\n                    )\n                )\n            except Exception as e:\n                print(f\"Error updating data loader config: {e}\")\n\n            # Sanity check to catch if we accidentally forget to update a UI component.\n            # We'll skip this check for now as it's causing issues with nested components\n            # assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n            return update_dict\n        except Exception as e:\n            print(f\"Error in update_ui_components_with_config_data: {e}\")\n            # Return a minimal update dict to avoid UI errors\n            return {self.model: config.model}\n\n    def update_config_with_ui_component_data(  # noqa: C901\n        self, orig_config: FluxLoraConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> FluxLoraConfig:\n        try:\n            # Handle the case where orig_config might be None\n            if orig_config is None:\n                from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig\n                from invoke_training.pipelines.flux.lora.config import FluxLoraConfig\n\n                # Create a default config\n                orig_config = FluxLoraConfig(\n                    model=\"black-forest-labs/FLUX.1-dev\",\n                    optimizer=AdamOptimizerConfig(),\n                )\n\n            new_config = orig_config.model_copy(deep=True)\n\n            # Create a copy of ui_data to avoid modifying the original\n            ui_data_copy = ui_data.copy()\n\n            # Helper function to safely pop values from ui_data\n            def safe_pop(component, default=None):\n                
try:\n                    return ui_data_copy.pop(component)\n                except (KeyError, TypeError) as e:\n                    print(f\"Error popping {component}: {e}\")\n                    return default\n\n            # Set basic properties\n            new_config.model = safe_pop(self.model, new_config.model)\n            new_config.train_transformer = safe_pop(self.train_transformer, new_config.train_transformer)\n            # Note: train_text_encoder and text_encoder_learning_rate are not supported for Flux LoRA\n            transformer_lr_value = safe_pop(self.transformer_learning_rate, new_config.transformer_learning_rate)\n            new_config.transformer_learning_rate = None if transformer_lr_value == 0 else transformer_lr_value\n            new_config.gradient_accumulation_steps = safe_pop(\n                self.gradient_accumulation_steps, new_config.gradient_accumulation_steps\n            )\n            new_config.gradient_checkpointing = safe_pop(self.gradient_checkpointing, new_config.gradient_checkpointing)\n            # Training/saving/validating steps/epochs are handled by BasePipelineConfigGroup\n            new_config.lr_scheduler = safe_pop(self.lr_scheduler, new_config.lr_scheduler)\n            new_config.lr_warmup_steps = safe_pop(self.lr_warmup_steps, new_config.lr_warmup_steps)\n\n            new_config.lora_rank_dim = safe_pop(self.lora_rank_dim, new_config.lora_rank_dim)\n            new_config.min_snr_gamma = safe_pop(self.min_snr_gamma, new_config.min_snr_gamma)\n            max_grad_norm_value = safe_pop(self.max_grad_norm, new_config.max_grad_norm)\n            new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n            new_config.train_batch_size = safe_pop(self.train_batch_size, new_config.train_batch_size)\n            new_config.weight_dtype = safe_pop(self.weight_dtype, new_config.weight_dtype)\n            new_config.mixed_precision = safe_pop(self.mixed_precision, 
new_config.mixed_precision)\n            new_config.lora_checkpoint_format = safe_pop(self.lora_checkpoint_format, new_config.lora_checkpoint_format)\n            new_config.timestep_sampler = safe_pop(self.timestep_sampler, new_config.timestep_sampler)\n            new_config.discrete_flow_shift = safe_pop(self.discrete_flow_shift, new_config.discrete_flow_shift)\n            new_config.sigmoid_scale = safe_pop(self.sigmoid_scale, new_config.sigmoid_scale)\n            new_config.lora_scale = safe_pop(self.lora_scale, new_config.lora_scale)\n            new_config.guidance_scale = safe_pop(self.guidance_scale, new_config.guidance_scale)\n            new_config.use_masks = safe_pop(self.use_masks, new_config.use_masks)\n            new_config.prediction_type = safe_pop(self.prediction_type, new_config.prediction_type)\n            new_config.max_checkpoints = safe_pop(self.max_checkpoints, new_config.max_checkpoints)\n\n            # Preserve the target modules from the original config\n            # These are not UI components but need to be preserved\n            if hasattr(orig_config, \"flux_lora_target_modules\") and orig_config.flux_lora_target_modules:\n                new_config.flux_lora_target_modules = orig_config.flux_lora_target_modules\n            if (\n                hasattr(orig_config, \"text_encoder_lora_target_modules\")\n                and orig_config.text_encoder_lora_target_modules\n            ):\n                new_config.text_encoder_lora_target_modules = orig_config.text_encoder_lora_target_modules\n\n            # Handle validation prompts\n            try:\n                validation_prompts_text = safe_pop(self.validation_prompts, \"\")\n                positive_prompts = validation_prompts_text\n                new_config.validation_prompts = positive_prompts\n            except Exception as e:\n                print(f\"Error processing validation prompts: {e}\")\n\n            new_config.num_validation_images_per_prompt = 
safe_pop(\n                self.num_validation_images_per_prompt, new_config.num_validation_images_per_prompt\n            )\n\n            # Update nested configs\n            try:\n                data_loader_config_group = self.image_caption_sd_data_loader_config_group\n                # Handle the case where data_loader might be None\n                new_config.data_loader = data_loader_config_group.update_config_with_ui_component_data(\n                    new_config.data_loader, ui_data_copy\n                )\n            except Exception as e:\n                print(f\"Error updating data loader config: {e}\")\n\n            try:\n                base_pipeline_group = self.base_pipeline_config_group\n                new_config = base_pipeline_group.update_config_with_ui_component_data(new_config, ui_data_copy)\n            except Exception as e:\n                print(f\"Error updating base pipeline config: {e}\")\n\n            try:\n                # Handle the case where optimizer might be None\n                if new_config.optimizer is None:\n                    from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig\n\n                    new_config.optimizer = AdamOptimizerConfig()\n\n                new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n                    new_config.optimizer, ui_data_copy\n                )\n            except Exception as e:\n                print(f\"Error updating optimizer config: {e}\")\n\n            # We're more lenient with the assertion now\n            if len(ui_data_copy) > 0:\n                print(f\"Warning: {len(ui_data_copy)} UI components were not transferred to the config\")\n\n            return new_config\n\n        except Exception as e:\n            print(f\"Error in update_config_with_ui_component_data: {e}\")\n            # Return the original config to avoid errors\n            return orig_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig\nfrom invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup\nfrom invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\n\nclass ImageCaptionSDDataLoaderConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Tab(\"Data Source Configs\"):\n            with gr.Row():\n                with gr.Column(scale=1):\n                    with gr.Group():\n                        self.dataset = DatasetConfigGroup(\n                            allowed_types=[\n                                \"HF_HUB_IMAGE_CAPTION_DATASET\",\n                                \"IMAGE_CAPTION_JSONL_DATASET\",\n                                \"IMAGE_CAPTION_DIR_DATASET\",\n                            ]\n                        )\n                with gr.Column(scale=3):\n                    with gr.Tab(\"Data Loading Configs\"):\n                        with gr.Group():\n                            with gr.Row():\n                                self.resolution = gr.Number(\n                                    label=\"Resolution\",\n                                    info=\"The resolution for input images. All of the images in the dataset will be\"\n                                    \" resized to this resolution unless the aspect_ratio_buckets config is set.\",\n                                    precision=0,\n                                    interactive=True,\n                                )\n                                self.dataloader_num_workers = gr.Number(\n                                    label=\"Dataloading Workers\",\n                                    info=\"Number of subprocesses to use for data loading. 
0 means that the data will\"\n                                    \" be loaded in the main process.\",\n                                    precision=0,\n                                    interactive=True,\n                                )\n                            with gr.Row():\n                                self.center_crop = gr.Checkbox(\n                                    label=\"Center Crop\",\n                                    info=\"If set, input images will be center-cropped to the target resolution.\"\n                                    \" Otherwise, input images will be randomly cropped to the target resolution.\",\n                                    interactive=True,\n                                )\n                                self.random_flip = gr.Checkbox(\n                                    label=\"Random Flip\",\n                                    info=\"If set, random flip augmentations will be applied to input images.\",\n                                    interactive=True,\n                                )\n                            self.caption_prefix = gr.Textbox(\n                                label=\"Caption Prefix\",\n                                info=\"A prefix that will be prepended to all captions.\"\n                                \" If None, no prefix will be added.\",\n                                interactive=True,\n                            )\n                    with gr.Tab(\"Aspect Ratio Bucketing Configs\"):\n                        self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup()\n\n    def update_ui_components_with_config_data(\n        self, config: ImageCaptionSDDataLoaderConfig\n    ) -> dict[gr.components.Component, Any]:\n        update_dict = {\n            self.resolution: config.resolution,\n            self.center_crop: config.center_crop,\n            self.random_flip: config.random_flip,\n            self.caption_prefix: config.caption_prefix,\n            
self.dataloader_num_workers: config.dataloader_num_workers,\n        }\n\n        update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset))\n        update_dict.update(\n            self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets)\n        )\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self,\n        orig_config: ImageCaptionSDDataLoaderConfig,\n        ui_data: dict[gr.components.Component, Any],\n    ) -> ImageCaptionSDDataLoaderConfig:\n        # Handle the case where orig_config is None\n        if orig_config is None:\n            from invoke_training.config.data.data_loader_config import (\n                AspectRatioBucketConfig,\n                ImageCaptionSDDataLoaderConfig,\n            )\n            from invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig\n\n            # Create a default config\n            orig_config = ImageCaptionSDDataLoaderConfig(\n                type=\"IMAGE_CAPTION_SD_DATA_LOADER\",\n                dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=\"<path/to/data.jsonl>\"),\n                aspect_ratio_buckets=AspectRatioBucketConfig(),\n                resolution=512,\n                center_crop=False,\n                random_flip=True,\n                caption_prefix=None,\n                dataloader_num_workers=4,\n            )\n\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.dataset = self.dataset.update_config_with_ui_component_data(orig_config.dataset, ui_data)\n        new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data(\n            orig_config.aspect_ratio_buckets, ui_data\n        )\n        new_config.resolution = ui_data.pop(self.resolution)\n        new_config.center_crop = ui_data.pop(self.center_crop)\n        new_config.random_flip = ui_data.pop(self.random_flip)\n   
     new_config.caption_prefix = ui_data.pop(self.caption_prefix) or None\n        new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers)\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/optimizer_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\nOptimizerConfig = AdamOptimizerConfig | ProdigyOptimizerConfig\n\n\nclass AdamOptimizerConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.learning_rate = gr.Number(\n                    label=\"Learning Rate\",\n                    info=\"Initial learning rate to use (after the potential warmup period). Note that in some training \"\n                    \"pipelines this can be overriden for a specific group of params.\",\n                    interactive=True,\n                )\n                self.use_8bit = gr.Checkbox(\n                    label=\"Use 8-bit\",\n                    info=\"Use 8-bit Adam optimizer to reduce VRAM requirements. (Requires bitsandbytes.)\",\n                    interactive=True,\n                )\n        with gr.Tab(\"Advanced\"):\n            with gr.Row():\n                self.beta1 = gr.Number(label=\"beta1\", interactive=True)\n                self.beta2 = gr.Number(label=\"beta2\", interactive=True)\n            with gr.Row():\n                self.weight_decay = gr.Number(label=\"Weight Decay\", interactive=True)\n                self.epsilon = gr.Number(label=\"epsilon\", interactive=True)\n\n    def update_ui_components_with_config_data(self, config: AdamOptimizerConfig) -> dict[gr.components.Component, Any]:\n        return {\n            self.learning_rate: config.learning_rate,\n            self.beta1: config.beta1,\n            self.beta2: config.beta2,\n            self.weight_decay: config.weight_decay,\n            self.epsilon: config.epsilon,\n            self.use_8bit: config.use_8bit,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: 
AdamOptimizerConfig | None, ui_data: dict\n    ) -> OptimizerConfig:\n        assert orig_config is None\n\n        return AdamOptimizerConfig(\n            learning_rate=ui_data.pop(self.learning_rate),\n            beta1=ui_data.pop(self.beta1),\n            beta2=ui_data.pop(self.beta2),\n            weight_decay=ui_data.pop(self.weight_decay),\n            epsilon=ui_data.pop(self.epsilon),\n            use_8bit=ui_data.pop(self.use_8bit),\n        )\n\n\nclass ProdigyOptimizerConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.learning_rate = gr.Number(\n                    label=\"Learning Rate\",\n                    info=\"The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A \"\n                    \"value of 1.0 is recommended. Note that in some pipelines this can be overriden for specific \"\n                    \"groups of parameters.\",\n                    interactive=True,\n                )\n        with gr.Tab(\"Advanced\"):\n            with gr.Row():\n                self.weight_decay = gr.Number(label=\"Weight Decay\", interactive=True)\n            with gr.Row():\n                self.use_bias_correction = gr.Checkbox(label=\"Bias Correction\", interactive=True)\n                self.safeguard_warmup = gr.Checkbox(label=\"Safeguard Warmup\", interactive=True)\n\n    def update_ui_components_with_config_data(\n        self, config: ProdigyOptimizerConfig\n    ) -> dict[gr.components.Component, Any]:\n        return {\n            self.learning_rate: config.learning_rate,\n            self.weight_decay: config.weight_decay,\n            self.use_bias_correction: config.use_bias_correction,\n            self.safeguard_warmup: config.safeguard_warmup,\n        }\n\n    def update_config_with_ui_component_data(\n        self, orig_config: ProdigyOptimizerConfig | None, ui_data: dict\n    ) -> OptimizerConfig:\n        assert 
orig_config is None\n\n        return ProdigyOptimizerConfig(\n            learning_rate=ui_data.pop(self.learning_rate),\n            weight_decay=ui_data.pop(self.weight_decay),\n            use_bias_correction=ui_data.pop(self.use_bias_correction),\n            safeguard_warmup=ui_data.pop(self.safeguard_warmup),\n        )\n\n\nclass OptimizerConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Group():\n            self.optimizer_type = gr.Dropdown(label=\"optimizer\", choices=[\"AdamW\", \"Prodigy\"], interactive=True)\n\n            with gr.Group() as adam_optimizer_config_group:\n                self.adam_optimizer_config = AdamOptimizerConfigGroup()\n            self.adam_optimizer_config_group = adam_optimizer_config_group\n\n            with gr.Group() as prodigy_optimizer_config_group:\n                self.prodigy_optimizer_config = ProdigyOptimizerConfigGroup()\n            self.prodigy_optimizer_config_group = prodigy_optimizer_config_group\n\n        self.optimizer_type.change(\n            self._on_optimizer_type_change,\n            inputs=[self.optimizer_type],\n            outputs=[self.adam_optimizer_config_group, self.prodigy_optimizer_config_group],\n        )\n\n    def _on_optimizer_type_change(self, optimizer_type: str):\n        return {\n            self.adam_optimizer_config_group: gr.Group(visible=optimizer_type == \"AdamW\"),\n            self.prodigy_optimizer_config_group: gr.Group(visible=optimizer_type == \"Prodigy\"),\n        }\n\n    def update_ui_components_with_config_data(self, config: OptimizerConfig) -> dict[gr.components.Component, Any]:\n        update_dict = {\n            self.optimizer_type: config.optimizer_type,\n            self.adam_optimizer_config_group: gr.Group(visible=config.optimizer_type == \"AdamW\"),\n            self.prodigy_optimizer_config_group: gr.Group(visible=config.optimizer_type == \"Prodigy\"),\n        }\n\n        update_dict.update(\n            
self.adam_optimizer_config.update_ui_components_with_config_data(\n                config if config.optimizer_type == \"AdamW\" else AdamOptimizerConfig()\n            )\n        )\n        update_dict.update(\n            self.prodigy_optimizer_config.update_ui_components_with_config_data(\n                config if config.optimizer_type == \"Prodigy\" else ProdigyOptimizerConfig()\n            )\n        )\n\n        return update_dict\n\n    def update_config_with_ui_component_data(self, orig_config: OptimizerConfig, ui_data: dict) -> OptimizerConfig:\n        # TODO: Use orig_config?\n\n        new_config_adam = self.adam_optimizer_config.update_config_with_ui_component_data(None, ui_data)\n        new_config_prodigy = self.prodigy_optimizer_config.update_config_with_ui_component_data(None, ui_data)\n\n        optimizer_type = ui_data.pop(self.optimizer_type)\n        if optimizer_type == \"AdamW\":\n            return new_config_adam\n        elif optimizer_type == \"Prodigy\":\n            return new_config_prodigy\n        else:\n            raise ValueError(f\"Invalid optimizer type: {optimizer_type}\")\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sd_lora_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (\n    ImageCaptionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdLoraConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SD_LORA configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. This is an\"\n                    \" alternative to increasing the batch size when training with limited VRAM. 
\"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdLoraConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_text_encoder_outputs = gr.Checkbox(\n                    label=\"Cache Text Encoder Outputs\",\n                    info=\"Cache the text encoder outputs to increase speed. This should not be used when training the \"\n                    \"text encoder or performing data augmentations that would change the text encoder outputs.\",\n                    interactive=True,\n                )\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. 
This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.train_unet = gr.Checkbox(label=\"Train UNet\", interactive=True)\n                self.train_text_encoder = gr.Checkbox(label=\"Train Text Encoder\", interactive=True)\n            with gr.Row():\n                self.unet_learning_rate = gr.Number(\n                    label=\"UNet Learning Rate\",\n                    info=\"The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n                self.text_encoder_learning_rate = gr.Number(\n                    label=\"Text Encoder Learning Rate\",\n                    info=\"The text encoder learning rate. 
Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdLoraConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.lora_rank_dim = gr.Number(\n                    label=\"LoRA Rank Dim\",\n                    info=\"The rank dimension to use for the LoRA layers. Increasing the rank dimension\"\n                    \" increases the model's expressivity, but also increases the size of the generated LoRA model.\",\n                    interactive=True,\n                    precision=0,\n                )\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minumum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(self, config: SdLoraConfig) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.max_checkpoints: config.max_checkpoints,\n            self.train_unet: config.train_unet,\n            self.unet_learning_rate: config.unet_learning_rate,\n            self.train_text_encoder: config.train_text_encoder,\n            self.text_encoder_learning_rate: config.text_encoder_learning_rate,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n            self.gradient_checkpointing: config.gradient_checkpointing,\n            self.lora_rank_dim: config.lora_rank_dim,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n        update_dict.update(\n            
self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdLoraConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdLoraConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.train_unet = ui_data.pop(self.train_unet)\n        unet_lr_value = ui_data.pop(self.unet_learning_rate)\n        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value\n        new_config.train_text_encoder = ui_data.pop(self.train_text_encoder)\n        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)\n        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)\n        new_config.cache_vae_outputs = 
ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)\n        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(\n            new_config.data_loader, ui_data\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (\n    TextualInversionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdTextualInversionConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SD_TEXTUAL_INVERSION configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Textual Inversion Configs\")\n        self.num_vectors = gr.Number(\n            label=\"Num Vectors\",\n            info=\"The number of TI vectors that will be trained. Can be overridden by 'Initial Phrase'.\",\n            interactive=True,\n            precision=0,\n        )\n        self.placeholder_token = gr.Textbox(\n            label=\"Placeholder Token\",\n            info=\"The special word to associate the learned embeddings with. 
Choose a unique token that is unlikely to \"\n            \"already exist in the tokenizer's vocabulary.\",\n            interactive=True,\n        )\n        self.initializer_token = gr.Textbox(\n            label=\"Initializer Token\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an \"\n            \"initializer for the placeholder token. It should be a single word that roughly describes the object or \"\n            \"style that you're trying to train on. The initializer token must map to a single tokenizer token.\",\n            interactive=True,\n        )\n        self.initial_phrase = gr.Textbox(\n            label=\"Initial Phrase\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to \"\n            \"initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding \"\n            \"embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be \"\n            \"inferred from the length of the tokenized phrase, so keep the phrase short.\",\n            interactive=True,\n        )\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. 
This is an\"\n                    \" alternative to increasing the batch size when training with limited VRAM. \"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdTextualInversionConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. 
This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdTextualInversionConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minimum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: SdTextualInversionConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.num_vectors: config.num_vectors,\n            self.placeholder_token: config.placeholder_token,\n            self.initializer_token: config.initializer_token,\n            self.initial_phrase: config.initial_phrase,\n            self.max_checkpoints: config.max_checkpoints,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n            self.gradient_checkpointing: config.gradient_checkpointing,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n        update_dict.update(\n            self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        
update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdTextualInversionConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.num_vectors = ui_data.pop(self.num_vectors)\n        new_config.placeholder_token = ui_data.pop(self.placeholder_token)\n        new_config.initializer_token = ui_data.pop(self.initializer_token) or None\n        new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n        new_config.gradient_checkpointing = 
ui_data.pop(self.gradient_checkpointing)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = (\n            self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data(\n                new_config.data_loader, ui_data\n            )\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (\n    ImageCaptionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdxlFinetuneConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SDXL_FINETUNE configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.vae_model = gr.Textbox(\n                        label=\"VAE Model\",\n                        info=\"(optional) If set, this overrides the base model's default VAE model.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.save_checkpoint_format = gr.Dropdown(\n                        label=\"Checkpoint Format\",\n                        info=\"The save format for the checkpoints. `full_diffusers` saves the full model in diffusers \"\n                        \"format. `trained_only_diffusers` saves only the parts of the model that were finetuned \"\n                        \"(i.e. 
the UNet).\",\n                        choices=get_typing_literal_options(SdxlFinetuneConfig, \"save_checkpoint_format\"),\n                        interactive=True,\n                    )\n                    self.save_dtype = gr.Dropdown(\n                        label=\"Save Dtype\",\n                        info=\"The dtype to use when saving the model.\",\n                        choices=get_typing_literal_options(SdxlFinetuneConfig, \"save_dtype\"),\n                        interactive=True,\n                    )\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. 
This is an alternative\"\n                    \" to increasing the batch size when training with limited VRAM. \"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdxlFinetuneConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_text_encoder_outputs = gr.Checkbox(\n                    label=\"Cache Text Encoder Outputs\",\n                    info=\"Cache the text encoder outputs to increase speed. This should not be used when training the \"\n                    \"text encoder or performing data augmentations that would change the text encoder outputs.\",\n                    interactive=True,\n                )\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. 
This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdxlFinetuneConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Row():\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minimum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: SdxlFinetuneConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.vae_model: config.vae_model,\n            self.save_checkpoint_format: config.save_checkpoint_format,\n            self.save_dtype: config.save_dtype,\n            self.max_checkpoints: config.max_checkpoints,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n            self.gradient_checkpointing: config.gradient_checkpointing,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n        update_dict.update(\n            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        
update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdxlFinetuneConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.vae_model = ui_data.pop(self.vae_model) or None\n        new_config.save_checkpoint_format = ui_data.pop(self.save_checkpoint_format)\n        new_config.save_dtype = ui_data.pop(self.save_dtype)\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)\n        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n  
      new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(\n            new_config.data_loader, ui_data\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (\n    SdxlLoraAndTextualInversionConfig,\n)\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (\n    TextualInversionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdxlLoraAndTextualInversionConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SDXL_LORA_AND_TEXTUAL_INVERSION configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.vae_model = gr.Textbox(\n                        label=\"VAE Model\",\n                        info=\"(optional) If set, this overrides the base model's default VAE model.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.image_caption_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Textual Inversion Configs\")\n        self.num_vectors = gr.Number(\n            label=\"Num Vectors\",\n            info=\"The number of TI vectors that will be trained. 
Can be overridden by 'Initial Phrase'.\",\n            interactive=True,\n            precision=0,\n        )\n        self.placeholder_token = gr.Textbox(\n            label=\"Placeholder Token\",\n            info=\"The special word to associate the learned embeddings with. Choose a unique token that is unlikely to \"\n            \"already exist in the tokenizer's vocabulary.\",\n            interactive=True,\n        )\n        self.initializer_token = gr.Textbox(\n            label=\"Initializer Token\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an \"\n            \"initializer for the placeholder token. It should be a single word that roughly describes the object or \"\n            \"style that you're trying to train on. The initializer token must map to a single tokenizer token.\",\n            interactive=True,\n        )\n        self.initial_phrase = gr.Textbox(\n            label=\"Initial Phrase\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to \"\n            \"initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding \"\n            \"embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be \"\n            \"inferred from the length of the tokenized phrase, so keep the phrase short.\",\n            interactive=True,\n        )\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. 
This is an alternative\"\n                    \" to increasing the batch size when training with limited VRAM. \"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_text_encoder_outputs = gr.Checkbox(\n                    label=\"Cache Text Encoder Outputs\",\n                    info=\"Cache the text encoder outputs to increase speed. This should not be used when training the \"\n                    \"text encoder or performing data augmentations that would change the text encoder outputs.\",\n                    interactive=True,\n                )\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. 
This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.train_unet = gr.Checkbox(label=\"Train UNet\", interactive=True)\n                self.train_text_encoder = gr.Checkbox(label=\"Train Text Encoder\", interactive=True)\n                self.train_ti = gr.Checkbox(label=\"Train Textual Inversion Token\", scale=2, interactive=True)\n            with gr.Row():\n                self.unet_learning_rate = gr.Number(\n                    label=\"UNet Learning Rate\",\n                    info=\"The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n                self.text_encoder_learning_rate = gr.Number(\n                    label=\"Text Encoder Learning Rate\",\n                    info=\"The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n                self.textual_inversion_learning_rate = gr.Number(\n                    label=\"Textual Inversion Learning Rate\",\n                    info=\"The textual inversion learning rate. 
Set to 0 or leave empty to inherit from the base \"\n                    \"optimizer learning rate.\",\n                    interactive=True,\n                )\n                self.ti_train_steps_ratio = gr.Number(label=\"Textual Inversion Train Steps Ratio\", interactive=True)\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.lora_rank_dim = gr.Number(\n                    label=\"LoRA Rank Dim\",\n                    info=\"The rank dimension to use for the LoRA layers. Increasing the rank dimension increases\"\n                    \" the model's expressivity, but also increases the size of the generated LoRA model.\",\n                    interactive=True,\n                    precision=0,\n                )\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minimum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: SdxlLoraAndTextualInversionConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.vae_model: config.vae_model,\n            self.num_vectors: config.num_vectors,\n            self.placeholder_token: config.placeholder_token,\n            self.initializer_token: config.initializer_token,\n            self.initial_phrase: config.initial_phrase,\n            self.max_checkpoints: config.max_checkpoints,\n            self.train_unet: config.train_unet,\n            self.train_text_encoder: config.train_text_encoder,\n            self.train_ti: config.train_ti,\n            self.unet_learning_rate: config.unet_learning_rate,\n            self.text_encoder_learning_rate: config.text_encoder_learning_rate,\n            self.textual_inversion_learning_rate: config.textual_inversion_learning_rate,\n            self.ti_train_steps_ratio: config.ti_train_steps_ratio,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n  
          self.gradient_checkpointing: config.gradient_checkpointing,\n            self.lora_rank_dim: config.lora_rank_dim,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n        update_dict.update(\n            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdxlLoraAndTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdxlLoraAndTextualInversionConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.vae_model = ui_data.pop(self.vae_model) or None\n        new_config.num_vectors = ui_data.pop(self.num_vectors)\n        new_config.placeholder_token = ui_data.pop(self.placeholder_token)\n        new_config.initializer_token = ui_data.pop(self.initializer_token) or None\n        new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.train_unet = ui_data.pop(self.train_unet)\n        new_config.train_text_encoder = 
ui_data.pop(self.train_text_encoder)\n        new_config.train_ti = ui_data.pop(self.train_ti)\n        unet_lr_value = ui_data.pop(self.unet_learning_rate)\n        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value\n        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)\n        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value\n        ti_lr_value = ui_data.pop(self.textual_inversion_learning_rate)\n        new_config.textual_inversion_learning_rate = None if ti_lr_value == 0 else ti_lr_value\n        new_config.ti_train_steps_ratio = ui_data.pop(self.ti_train_steps_ratio)\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)\n        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)\n        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = 
convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(\n            new_config.data_loader, ui_data\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_lora_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (\n    ImageCaptionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdxlLoraConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SD_LORA configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.vae_model = gr.Textbox(\n                        label=\"VAE Model\",\n                        info=\"(optional) If set, this overrides the base model's default VAE model.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. 
Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. This is an alternative\"\n                    \" to increasing the batch size when training with limited VRAM. \"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdxlLoraConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_text_encoder_outputs = gr.Checkbox(\n                    label=\"Cache Text Encoder Outputs\",\n                    info=\"Cache the text encoder outputs to increase speed. 
This should not be used when training the \"\n                    \"text encoder or performing data augmentations that would change the text encoder outputs.\",\n                    interactive=True,\n                )\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.train_unet = gr.Checkbox(label=\"Train UNet\", interactive=True)\n                self.train_text_encoder = gr.Checkbox(label=\"Train Text Encoder\", interactive=True)\n            with gr.Row():\n                self.unet_learning_rate = gr.Number(\n                    label=\"UNet Learning Rate\",\n                    info=\"The UNet learning rate. 
Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n                self.text_encoder_learning_rate = gr.Number(\n                    label=\"Text Encoder Learning Rate\",\n                    info=\"The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer \"\n                    \"learning rate.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdxlLoraConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.lora_rank_dim = gr.Number(\n                    label=\"LoRA Rank Dim\",\n                    info=\"The rank dimension to use for the LoRA layers. 
Increasing the rank dimension increases\"\n                    \" the model's expressivity, but also increases the size of the generated LoRA model.\",\n                    interactive=True,\n                    precision=0,\n                )\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minimum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: SdxlLoraConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.vae_model: config.vae_model,\n            self.max_checkpoints: config.max_checkpoints,\n            self.train_unet: config.train_unet,\n            self.unet_learning_rate: config.unet_learning_rate,\n            self.train_text_encoder: config.train_text_encoder,\n            self.text_encoder_learning_rate: config.text_encoder_learning_rate,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n            self.gradient_checkpointing: config.gradient_checkpointing,\n            self.lora_rank_dim: config.lora_rank_dim,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n 
       update_dict.update(\n            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdxlLoraConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdxlLoraConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.vae_model = ui_data.pop(self.vae_model) or None\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.train_unet = ui_data.pop(self.train_unet)\n        unet_lr_value = ui_data.pop(self.unet_learning_rate)\n        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value\n        new_config.train_text_encoder = ui_data.pop(self.train_text_encoder)\n        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)\n        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        
new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)\n        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)\n        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(\n            new_config.data_loader, ui_data\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py",
    "content": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig\nfrom invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup\nfrom invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup\nfrom invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (\n    TextualInversionSDDataLoaderConfigGroup,\n)\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n)\nfrom invoke_training.ui.utils.utils import get_typing_literal_options\n\n\nclass SdxlTextualInversionConfigGroup(UIConfigElement):\n    def __init__(self):\n        \"\"\"The SDXL_TEXTUAL_INVERSION configs.\"\"\"\n\n        gr.Markdown(\"## Basic Configs\")\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Base Model\"):\n                    self.model = gr.Textbox(\n                        label=\"Model\",\n                        info=\"The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in \"\n                        \"diffusers or checkpoint format).\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.hf_variant = gr.Textbox(\n                        label=\"Variant\",\n                        info=\"(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a\"\n                        \" HF Hub model name.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n                    self.vae_model = gr.Textbox(\n                        label=\"VAE Model\",\n                        info=\"(optional) If set, this overrides the base model's default VAE model.\",\n                        type=\"text\",\n                        interactive=True,\n                    )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Training Outputs\"):\n                    self.base_pipeline_config_group = BasePipelineConfigGroup()\n                    self.max_checkpoints = gr.Number(\n                        label=\"Maximum Number of Checkpoints\",\n                        info=\"The maximum number of checkpoints to keep on disk from this training run. Earlier \"\n                        \"checkpoints will be deleted to respect this limit.\",\n                        interactive=True,\n                        precision=0,\n                    )\n\n        gr.Markdown(\"## Data Configs\")\n        self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()\n\n        gr.Markdown(\"## Textual Inversion Configs\")\n        self.num_vectors = gr.Number(\n            label=\"Num Vectors\",\n            info=\"The number of TI vectors that will be trained. 
Can be overriden by 'Initial Phrase'.\",\n            interactive=True,\n            precision=0,\n        )\n        self.placeholder_token = gr.Textbox(\n            label=\"Placeholder Token\",\n            info=\"The special word to associate the learned embeddings with. Choose a unique token that is unlikely to \"\n            \"already exist in the tokenizer's vocabulary.\",\n            interactive=True,\n        )\n        self.initializer_token = gr.Textbox(\n            label=\"Initializer Token\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an \"\n            \"initializer for the placeholder token. It should be a single word that roughly describes the object or \"\n            \"style that you're trying to train on. The initializer token ust map to a single tokenizer token.\",\n            interactive=True,\n        )\n        self.initial_phrase = gr.Textbox(\n            label=\"Initial Phrase\",\n            info=\"Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to \"\n            \"initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding \"\n            \"embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be \"\n            \"inferred from the length of the tokenized phrase, so keep the phrase short.\",\n            interactive=True,\n        )\n\n        gr.Markdown(\"## Optimizer Configs\")\n        self.optimizer_config_group = OptimizerConfigGroup()\n\n        gr.Markdown(\"## Speed / Memory Configs\")\n        with gr.Group():\n            with gr.Row():\n                self.gradient_accumulation_steps = gr.Number(\n                    label=\"Gradient Accumulation Steps\",\n                    info=\"The number of gradient steps to accumulate before each weight update. 
This is an\"\n                    \" alternative to increasing the batch size when training with limited VRAM.\"\n                    \"effective_batch_size = train_batch_size * gradient_accumulation_steps.\",\n                    precision=0,\n                    interactive=True,\n                )\n            with gr.Row():\n                self.weight_dtype = gr.Dropdown(\n                    label=\"Weight Type\",\n                    info=\"The precision of the model weights. Lower precision can speed up training and reduce memory, \"\n                    \"with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases \"\n                    \"if your GPU supports it.\",\n                    choices=get_typing_literal_options(SdxlTextualInversionConfig, \"weight_dtype\"),\n                    interactive=True,\n                )\n            with gr.Row():\n                self.cache_vae_outputs = gr.Checkbox(\n                    label=\"Cache VAE Outputs\",\n                    info=\"Cache the VAE outputs to increase speed. This should not be used when training the UNet or \"\n                    \"performing data augmentations that would change the VAE outputs.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.enable_cpu_offload_during_validation = gr.Checkbox(\n                    label=\"Enable CPU Offload during Validation\",\n                    info=\"Offload models to the CPU sequentially during validation. 
This reduces peak VRAM \"\n                    \"requirements at the cost of slower validation during training.\",\n                    interactive=True,\n                )\n                self.gradient_checkpointing = gr.Checkbox(\n                    label=\"Gradient Checkpointing\",\n                    info=\"If True, VRAM requirements are reduced at the cost of ~20% slower training\",\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## General Training Configs\")\n        with gr.Tab(\"Core\"):\n            with gr.Row():\n                self.lr_scheduler = gr.Dropdown(\n                    label=\"Learning Rate Scheduler\",\n                    choices=get_typing_literal_options(SdxlTextualInversionConfig, \"lr_scheduler\"),\n                    interactive=True,\n                )\n                self.lr_warmup_steps = gr.Number(\n                    label=\"Warmup Steps\",\n                    info=\"The number of warmup steps in the \"\n                    \"learning rate schedule, if applicable to the selected scheduler.\",\n                    interactive=True,\n                )\n            with gr.Row():\n                self.use_masks = gr.Checkbox(\n                    label=\"Use Masks\", info=\"This can only be enabled if the dataset contains masks.\", interactive=True\n                )\n\n        with gr.Tab(\"Advanced\"):\n            with gr.Column():\n                self.min_snr_gamma = gr.Number(\n                    label=\"Minumum SNR Gamma\",\n                    info=\"min_snr_gamma acts like an an upper bound on the weight of samples with low noise \"\n                    \"levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended \"\n                    \"value is min_snr gamma = 5.0.\",\n                    interactive=True,\n                )\n                self.max_grad_norm = gr.Number(\n                    label=\"Max Gradient Norm\",\n                    info=\"Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).\",\n                    interactive=True,\n                )\n                self.train_batch_size = gr.Number(\n                    label=\"Batch Size\",\n                    info=\"The Training Batch Size - Higher values require increasing amounts of VRAM.\",\n                    precision=0,\n                    interactive=True,\n                )\n\n        gr.Markdown(\"## Validation\")\n        with gr.Group():\n            self.validation_prompts = gr.Textbox(\n                label=\"Validation Prompts\",\n                info=\"Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' \"\n                \"delimiter. For example: `positive prompt[NEG]negative prompt`. 
\",\n                lines=5,\n                interactive=True,\n            )\n            self.num_validation_images_per_prompt = gr.Number(\n                label=\"# of Validation Images to Generate per Prompt\", precision=0, interactive=True\n            )\n\n    def update_ui_components_with_config_data(\n        self, config: SdxlTextualInversionConfig\n    ) -> dict[gr.components.Component, typing.Any]:\n        update_dict = {\n            self.model: config.model,\n            self.hf_variant: config.hf_variant,\n            self.vae_model: config.vae_model,\n            self.num_vectors: config.num_vectors,\n            self.placeholder_token: config.placeholder_token,\n            self.initializer_token: config.initializer_token,\n            self.initial_phrase: config.initial_phrase,\n            self.max_checkpoints: config.max_checkpoints,\n            self.lr_scheduler: config.lr_scheduler,\n            self.lr_warmup_steps: config.lr_warmup_steps,\n            self.use_masks: config.use_masks,\n            self.max_grad_norm: config.max_grad_norm,\n            self.train_batch_size: config.train_batch_size,\n            self.cache_vae_outputs: config.cache_vae_outputs,\n            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,\n            self.gradient_accumulation_steps: config.gradient_accumulation_steps,\n            self.weight_dtype: config.weight_dtype,\n            self.gradient_checkpointing: config.gradient_checkpointing,\n            self.min_snr_gamma: config.min_snr_gamma,\n            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(\n                config.validation_prompts, config.negative_validation_prompts\n            ),\n            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,\n        }\n        update_dict.update(\n            
self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)\n        )\n        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))\n        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))\n\n        # Sanity check to catch if we accidentally forget to update a UI component.\n        assert set(update_dict.keys()) == set(self.get_ui_output_components())\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: SdxlTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]\n    ) -> SdxlTextualInversionConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        new_config.model = ui_data.pop(self.model)\n        new_config.hf_variant = ui_data.pop(self.hf_variant) or None\n        new_config.vae_model = ui_data.pop(self.vae_model) or None\n        new_config.num_vectors = ui_data.pop(self.num_vectors)\n        new_config.placeholder_token = ui_data.pop(self.placeholder_token)\n        new_config.initializer_token = ui_data.pop(self.initializer_token) or None\n        new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None\n        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)\n        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)\n        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)\n        new_config.use_masks = ui_data.pop(self.use_masks)\n        max_grad_norm_value = ui_data.pop(self.max_grad_norm)\n        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value\n        new_config.train_batch_size = ui_data.pop(self.train_batch_size)\n        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)\n        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)\n        
new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)\n        new_config.weight_dtype = ui_data.pop(self.weight_dtype)\n        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)\n        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)\n        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)\n\n        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))\n        new_config.validation_prompts = positive_prompts\n        new_config.negative_validation_prompts = negative_prompts\n\n        new_config.data_loader = (\n            self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data(\n                new_config.data_loader, ui_data\n            )\n        )\n        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)\n        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(\n            new_config.optimizer, ui_data\n        )\n\n        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred\n        # to the config.\n        assert len(ui_data) == 0\n\n        return new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import (\n    TextualInversionSDDataLoaderConfig,\n)\nfrom invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup\nfrom invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\n\n\nclass TextualInversionSDDataLoaderConfigGroup(UIConfigElement):\n    def __init__(self):\n        with gr.Row():\n            with gr.Column(scale=1):\n                with gr.Tab(\"Data Source Configs\"):\n                    with gr.Group():\n                        self.dataset = DatasetConfigGroup(\n                            allowed_types=[\n                                \"HF_HUB_IMAGE_CAPTION_DATASET\",\n                                \"IMAGE_CAPTION_JSONL_DATASET\",\n                                \"IMAGE_CAPTION_DIR_DATASET\",\n                                \"IMAGE_DIR_DATASET\",\n                            ]\n                        )\n            with gr.Column(scale=3):\n                with gr.Tab(\"Data Loading Configs\"):\n                    with gr.Group():\n                        self.caption_preset = gr.Dropdown(\n                            label=\"Caption Preset\",\n                            choices=[\"None\", \"style\", \"object\"],\n                            info=\"Only one of 'Caption Preset' or 'Caption Templates' should be set.\\nSelect a Caption \"\n                            \"Preset option to use a set of pre-configured templates.\",\n                            interactive=True,\n                        )\n                        self.caption_templates = gr.Textbox(\n                            label=\"Caption Templates\",\n                            info=\"Only one of 'Caption Preset' or 'Caption Templates' should be set. 
Enter one template\"\n                            \" per line. Each template should contain a single placeholder token slot indicated by '{}',\"\n                            \" for example 'a photo of a {}'.\",\n                            lines=5,\n                            interactive=True,\n                        )\n                        with gr.Row():\n                            self.keep_original_captions = gr.Checkbox(\n                                label=\"Keep Original Captions\",\n                                info=\"If True, the caption templates will be prepended to the original captions.\"\n                                \" If False, the caption templates will replace the original captions.\",\n                                interactive=True,\n                            )\n                            self.shuffle_caption_delimiter = gr.Textbox(\n                                label=\"Shuffle Caption Delimiter\",\n                                info=\"Set captions to split on the provided delimiter (e.g. ',') and shuffled.\",\n                                interactive=True,\n                            )\n\n                        with gr.Row():\n                            self.resolution = gr.Number(\n                                label=\"Resolution\",\n                                info=\"The resolution for input images. All of the images in the dataset will be\"\n                                \" resized to this resolution unless the aspect_ratio_buckets config is set.\",\n                                precision=0,\n                                interactive=True,\n                            )\n                            self.dataloader_num_workers = gr.Number(\n                                label=\"Dataloading Workers\",\n                                info=\"Number of subprocesses to use for data loading. 
0 means that the data will\"\n                                \" be loaded in the main process.\",\n                                precision=0,\n                                interactive=True,\n                            )\n                        with gr.Row():\n                            self.center_crop = gr.Checkbox(\n                                label=\"Center Crop\",\n                                info=\"If set, input images will be center-cropped to the target resolution. Otherwise,\"\n                                \" input images will be randomly cropped to the target resolution.\",\n                                interactive=True,\n                            )\n                            self.random_flip = gr.Checkbox(\n                                label=\"Random Flip\",\n                                info=\"If set, random flip augmentations will be applied to input images.\",\n                                interactive=True,\n                            )\n                with gr.Tab(\"Aspect Ratio Bucketing Configs\"):\n                    self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup()\n\n    def update_ui_components_with_config_data(\n        self, config: TextualInversionSDDataLoaderConfig\n    ) -> dict[gr.components.Component, Any]:\n        # Special handling of caption_preset to translate None to \"None\".\n        caption_preset = \"None\"\n        if config.caption_preset is not None:\n            caption_preset = config.caption_preset\n\n        update_dict = {\n            self.caption_preset: caption_preset,\n            self.caption_templates: \"\\n\".join(config.caption_templates or []),\n            self.keep_original_captions: config.keep_original_captions,\n            self.shuffle_caption_delimiter: config.shuffle_caption_delimiter,\n            self.resolution: config.resolution,\n            self.center_crop: config.center_crop,\n            self.random_flip: config.random_flip,\n            
self.dataloader_num_workers: config.dataloader_num_workers,\n        }\n\n        update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset))\n        update_dict.update(\n            self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets)\n        )\n\n        return update_dict\n\n    def update_config_with_ui_component_data(\n        self, orig_config: TextualInversionSDDataLoaderConfig, ui_data: dict[gr.components.Component, Any]\n    ) -> TextualInversionSDDataLoaderConfig:\n        new_config = orig_config.model_copy(deep=True)\n\n        # Special handling of caption_preset to translate \"None\" to None.\n        caption_presets = {\"None\": None, \"style\": \"style\", \"object\": \"object\"}\n        caption_preset = caption_presets[ui_data.pop(self.caption_preset)]\n\n        # Special handling of caption_templates.\n        caption_templates: list[str] = ui_data.pop(self.caption_templates).split(\"\\n\")\n        caption_templates = [x.strip() for x in caption_templates if x.strip() != \"\"] or None\n\n        new_config.dataset = self.dataset.update_config_with_ui_component_data(orig_config.dataset, ui_data)\n        new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data(\n            orig_config.aspect_ratio_buckets, ui_data\n        )\n        new_config.caption_preset = caption_preset\n        new_config.caption_templates = caption_templates\n        new_config.keep_original_captions = ui_data.pop(self.keep_original_captions)\n        new_config.shuffle_caption_delimiter = ui_data.pop(self.shuffle_caption_delimiter) or None\n        new_config.resolution = ui_data.pop(self.resolution)\n        new_config.center_crop = ui_data.pop(self.center_crop)\n        new_config.random_flip = ui_data.pop(self.random_flip)\n        new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers)\n\n        return 
new_config\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/ui_config_element.py",
    "content": "from typing import Any\n\nimport gradio as gr\n\n\nclass UIConfigElement:\n    \"\"\"A base class for UI blocks that represent a part of a config.\"\"\"\n\n    def get_ui_output_components(self) -> list[gr.components.Component]:\n        \"\"\"Recursively return a list of all valid output UI components.\"\"\"\n        all_ui_components = []\n        for attribute in vars(self).values():\n            if isinstance(attribute, (gr.components.Component, gr.Group)):\n                all_ui_components.append(attribute)\n            elif isinstance(attribute, UIConfigElement):\n                all_ui_components.extend(attribute.get_ui_output_components())\n        return all_ui_components\n\n    def get_ui_input_components(self) -> list[gr.components.Component]:\n        \"\"\"Recursively return a list of all valid input UI components.\"\"\"\n        all_ui_components = []\n        for attribute in vars(self).values():\n            if isinstance(attribute, (gr.components.Component)):\n                all_ui_components.append(attribute)\n            elif isinstance(attribute, UIConfigElement):\n                all_ui_components.extend(attribute.get_ui_input_components())\n        return all_ui_components\n\n    def update_ui_components_with_config_data(self, config) -> dict[gr.components.Component, Any]:\n        \"\"\"Produce a dictionary of UI components to their corresponding updated data from the config.\"\"\"\n        raise NotImplementedError()\n\n    def update_config_with_ui_component_data(self, orig_config, ui_data: dict[gr.components.Component, Any]):\n        \"\"\"Update the orig_config with the data from the UI components. Return the updated config.\"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "src/invoke_training/ui/gradio_blocks/header.py",
    "content": "import gradio as gr\n\nfrom invoke_training.ui.utils.utils import get_assets_dir_path\n\n\nclass Header:\n    def __init__(self):\n        logo_path = get_assets_dir_path() / \"logo.png\"\n        gr.Image(\n            value=logo_path,\n            label=\"Invoke Training App\",\n            width=200,\n            interactive=False,\n            container=False,\n        )\n        gr.Markdown(\n            \"[Home](/)\\n\\n\"\n            \"*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --\"\n            \" Learn more about Invoke at [invoke.com](https://www.invoke.com/)\"\n        )\n"
  },
  {
    "path": "src/invoke_training/ui/gradio_blocks/pipeline_tab.py",
    "content": "import typing\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.ui.config_groups.ui_config_element import UIConfigElement\nfrom invoke_training.ui.utils.utils import load_config_from_yaml\n\n\nclass PipelineTab:\n    def __init__(\n        self,\n        name: str,\n        default_config_file_path: str,\n        pipeline_config_cls: typing.Type[PipelineConfig],\n        config_group_cls: typing.Type[UIConfigElement],\n        run_training_cb: typing.Callable[[PipelineConfig], None],\n        app: gr.Blocks,\n    ):\n        \"\"\"A tab for a single training pipeline type.\n\n        Args:\n            run_training_cb (typing.Callable[[PipelineConfig], None]): A callback function to run the training process.\n        \"\"\"\n        self._name = name\n        self._default_config_file_path = default_config_file_path\n        self._pipeline_config_cls = pipeline_config_cls\n        self._run_training_cb = run_training_cb\n\n        # self._default_config is the config that was last loaded from the reference config file.\n        self._default_config = None\n        # self._current_config is the config that was most recently generated from the UI.\n        self._current_config = None\n\n        gr.Markdown(f\"# {self._name} Training Config\")\n        self.reference_config_file = gr.Textbox(\n            label=\"Reference Config File Path\", value=default_config_file_path, interactive=True\n        )\n        reset_config_button = gr.Button(value=\"Reload reference config\")\n        self.pipeline_config_group = config_group_cls()\n\n        gr.Markdown(\"## Config Output\")\n        generate_config_button = gr.Button(value=\"Generate Config\")\n        self._config_yaml = gr.Code(label=\"Config YAML\", language=\"yaml\", interactive=False)\n\n        gr.Markdown(\n            \"\"\"# Run Training\n\n            'Start Training' starts the training process in the background. 
Check the terminal for logs.\n\n            **Warning: Click 'Generate Config' to capture all of the latest changes before starting training.**\n            \"\"\"\n        )\n        run_training_button = gr.Button(value=\"Start Training\")\n\n        gr.Markdown(\n            \"\"\"# Visualize Results\n\n        Once you've started training, you can see the results by launching tensorboard with the following\n        command:\n\n        ```bash\n        tensorboard --logdir /path/to/output_dir\n        ```\n\n        Alternatively, you can browse the output directory directly to find model checkpoints, logs, and validation\n        images.\n        \"\"\"\n        )\n\n        reset_config_button.click(\n            self.on_reset_config_button_click,\n            inputs=self.reference_config_file,\n            outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],\n        )\n        generate_config_button.click(\n            self.on_generate_config_button_click,\n            inputs=set(self.pipeline_config_group.get_ui_input_components()),\n            outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],\n        )\n\n        run_training_button.click(self.on_run_training_button_click, inputs=[], outputs=[])\n\n        # On app load, reset the configs based on the default reference config file.\n        # We'll wrap this in a try-except block to handle any errors during loading\n        def safe_load_config(file_path):\n            try:\n                return self.on_reset_config_button_click(file_path)\n            except Exception as e:\n                print(f\"Error during app.load for {self._name}: {e}\")\n                # Return empty values for all outputs to avoid UI errors\n                output_components = self.pipeline_config_group.get_ui_output_components() + [self._config_yaml]\n                return {comp: None for comp in output_components}\n\n        app.load(\n            
safe_load_config,\n            inputs=self.reference_config_file,\n            outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],\n        )\n\n    def on_reset_config_button_click(self, file_path: str):\n        try:\n            print(f\"Resetting UI configs for {self._name} to {file_path}.\")\n            default_config = load_config_from_yaml(file_path)\n\n            if not isinstance(default_config, self._pipeline_config_cls):\n                raise TypeError(\n                    f\"Wrong config type. Expected '{self._pipeline_config_cls.__name__}', got \"\n                    f\"'{type(default_config).__name__}'.\"\n                )\n\n            self._default_config = default_config\n            self._current_config = self._default_config.model_copy(deep=True)\n            update_dict = self.pipeline_config_group.update_ui_components_with_config_data(self._current_config)\n            update_dict.update({self._config_yaml: None})\n            return update_dict\n        except Exception as e:\n            print(f\"Error resetting config: {e}\")\n            # Return a minimal update dict to avoid UI errors\n            if self._current_config:\n                return {\n                    self._config_yaml: yaml.safe_dump(\n                        self._current_config.model_dump(), default_flow_style=False, sort_keys=False\n                    )\n                }\n            return {self._config_yaml: f\"Error loading config: {e}\"}\n\n    def on_generate_config_button_click(self, data: dict):\n        try:\n            print(f\"Generating config for {self._name}.\")\n            self._current_config = self.pipeline_config_group.update_config_with_ui_component_data(\n                self._current_config, data\n            )\n\n            # Roundtrip to make sure that the config is valid.\n            self._current_config = self._pipeline_config_cls.model_validate(self._current_config.model_dump())\n\n            # 
Update the UI to reflect the new state of the config\n            # (in case some values were rounded or otherwise modified\n            # in the process).\n            update_dict = self.pipeline_config_group.update_ui_components_with_config_data(self._current_config)\n            update_dict.update(\n                {\n                    self._config_yaml: yaml.safe_dump(\n                        self._current_config.model_dump(), default_flow_style=False, sort_keys=False\n                    )\n                }\n            )\n            return update_dict\n        except Exception as e:\n            print(f\"Error generating config: {e}\")\n            # Return a minimal update dict to avoid UI errors\n            if self._current_config:\n                return {\n                    self._config_yaml: yaml.safe_dump(\n                        self._current_config.model_dump(), default_flow_style=False, sort_keys=False\n                    )\n                }\n            return {self._config_yaml: f\"Error generating config: {e}\"}\n\n    def on_run_training_button_click(self):\n        self._run_training_cb(self._current_config)\n"
  },
  {
    "path": "src/invoke_training/ui/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1.0\">\n    <title>invoke-training</title>\n    <link rel=\"icon\" type=\"image/x-icon\" href=\"/assets/favicon.png\">\n    <link href=\"https://fonts.googleapis.com/css?family=Inter\" rel=\"stylesheet\">\n    <style>\n        body {\n            font-family: 'Inter';\n            font-size: 22px;\n        }\n        header {\n            margin: 20px;\n        }\n        #button_container {\n            margin-top: 40px;\n            margin-bottom: 40px;\n            margin-left: 20px;\n            margin-right: 20px;\n            display: flex;\n            flex-wrap: wrap;\n            justify-content: center;\n            align-items: center;\n        }\n        .main_link {\n            display: inline-block;\n            margin: 20px;\n            padding: 20px;\n            width: 300px;\n            border: 1px solid #000;\n            border-radius: 10px;\n            text-align: center;\n            text-decoration: none;\n            color: #000;\n            background-color: #ffffff;\n        }\n        .main_link:hover {\n            background-color: #f7f7f7;\n        }\n        .text-sm {\n            font-size: 16px;\n        }\n        .text-gray {\n            color: #6b7280;\n        }\n    </style>\n</head>\n\n<body>\n    <header>\n        <img src=\"/assets/logo.png\" alt=\"Invoke logo.\" width=\"200\">\n        <h1>invoke-training</h1>\n        <p><i>Invoke Training - </i><a href=\"https://invoke-ai.github.io/invoke-training/\" target=\"_blank\">Documentation</a></p>\n        <p>Learn more about Invoke at <a href=\"https://www.invoke.com/\" target=\"_blank\">invoke.com</a></p>\n    </header>\n    \n    <div id=\"button_container\">\n        <a href=\"/data\" class=\"main_link\">\n            <p>Datasets</p>\n            <p class=\"text-sm text-gray\">Prepare a dataset.</p>\n        
</a>\n        <a href=\"/train\" class=\"main_link\">\n            <p>Training</p>\n            <p class=\"text-sm text-gray\">Train a model.</p>\n        </a>\n    </div>\n\n</body>\n</html>"
  },
  {
    "path": "src/invoke_training/ui/pages/data_page.py",
    "content": "from pathlib import Path\n\nimport gradio as gr\nfrom PIL import Image\n\nfrom invoke_training._shared.data.datasets.image_caption_jsonl_dataset import (\n    CAPTION_COLUMN_DEFAULT,\n    IMAGE_COLUMN_DEFAULT,\n    ImageCaptionExample,\n    ImageCaptionJsonlDataset,\n)\nfrom invoke_training._shared.utils.jsonl import save_jsonl\nfrom invoke_training.ui.gradio_blocks.header import Header\n\nIMAGE_EXTENSIONS = [\".jpg\", \".jpeg\", \".png\"]\n\n\nclass DataPage:\n    def __init__(self):\n        # The dataset that is currently being edited.\n        self._jsonl_path: str | None = None\n        self._dataset: ImageCaptionJsonlDataset | None = None\n\n        # Define the theme with dark mode as default\n        theme = gr.themes.Default(\n            # Optional: Customize colors, fonts, etc.\n            # primary_hue=gr.themes.colors.blue,\n            # ...\n        )\n        theme._dark_mode = True\n\n        # Custom CSS\n        custom_css = \"\"\"\n        .dark {\n            /* Override the default accent color for dark mode */\n            --color-accent: #e6fd13 !important;\n            --color-accent-soft: #e6fd1333 !important; /* Optional: Adjust soft accent too */\n        }\n\n        .dark .tabs button[aria-selected=\"true\"] {\n            /* Keep selected tab text color override */\n            color: #e6fd13 !important;\n            /* Optional: Remove background if --color-accent handles it */\n            /* background-color: transparent !important; */\n        }\n\n        /* Style checkbox checkmark in dark mode when checked */\n        .dark input[type=\"checkbox\"]:checked + span svg path {\n             /* Target the SVG path inside the checked checkbox */\n            stroke: black !important; /* Set the checkmark color to black */\n        }\n        \"\"\"\n\n        # Pass the theme and css to gr.Blocks\n        with gr.Blocks(\n            theme=theme,\n            css=custom_css,  # Use updated CSS\n            
title=\"invoke-training\",\n            analytics_enabled=False,\n            head='<link rel=\"icon\" type=\"image/x-icon\" href=\"/assets/favicon.png\">',\n        ) as app:\n            self._header = Header()\n            gr.Markdown(\"# Data Annotation\")\n            gr.Markdown(\n                \"Note: This UI creates datasets in `IMAGE_CAPTION_JSONL_DATASET` format. For more information about \"\n                \"this format see [the docs](https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/)\"\n            )\n\n            # HACK: I use a column as a wrapper to control visibility of this group of UI elements. gr.Group sounds like\n            # a more natural choice for this purpose, but it applies some styling that makes the group look weird.\n            with gr.Column() as select_dataset_group:\n                gr.Markdown(\"## Load Existing Dataset\")\n                with gr.Group():\n                    self._existing_jsonl_path = gr.Textbox(\n                        label=\"Existing .jsonl Path\",\n                        info=\"Enter the path to an existing dataset's .jsonl file.\",\n                        placeholder=\"/path/to/dataset.jsonl\",\n                    )\n                    with gr.Row():\n                        self._image_column_textbox = gr.Textbox(\n                            label=\"Image Column (Optional)\", placeholder=IMAGE_COLUMN_DEFAULT\n                        )\n                        self._caption_column_textbox = gr.Textbox(\n                            label=\"Caption Column (Optional)\", placeholder=CAPTION_COLUMN_DEFAULT\n                        )\n                    self._load_existing_dataset_button = gr.Button(\"Load Existing Dataset\")\n                gr.Markdown(\"## Create New Dataset\")\n                with gr.Group():\n                    self._new_jsonl_path = gr.Textbox(\n                        label=\"New .jsonl Path\",\n                        info=\"Enter the path for a new 
.jsonl file.\",\n                        placeholder=\"/path/to/dataset.jsonl\",\n                    )\n                    self._create_new_dataset_button = gr.Button(\"Create New Dataset\")\n            self._select_dataset_group = select_dataset_group\n\n            # HACK: I use a column as a wrapper to control visibility of this group of UI elements. gr.Group sounds like\n            # a more natural choice for this purpose, but it applies some styling that makes the group look weird.\n            with gr.Column(visible=False) as edit_dataset_group:\n                with gr.Row():\n                    self._current_jsonl_path = gr.Textbox(label=\"Currently editing:\", interactive=False)\n                    self._change_dataset_button = gr.Button(\"Change\")\n                gr.Markdown(\"## Add Images\")\n                with gr.Group():\n                    self._image_source_textbox = gr.Textbox(\n                        label=\"Image Source\",\n                        info=\"Enter the path to a single image or a directory containing images. 
If a directory path \"\n                        \"is passed, it will be searched recursively for image files.\",\n                        placeholder=\"/path/to/image_dir\",\n                    )\n                    self._add_images_button = gr.Button(\"Add Images\")\n\n                gr.Markdown(\"## Edit Captions\")\n                with gr.Row():\n                    with gr.Column():\n                        with gr.Row():\n                            self._cur_example_index = gr.Number(label=\"Current index\", precision=0, interactive=True)\n                            self._cur_len_number = gr.Number(label=\"Dataset length\", interactive=False)\n                        with gr.Row():\n                            self._beyond_dataset_limits_warning = gr.Markdown(\n                                \"**Current index is beyond dataset limits.** If you have completed all captions, click \"\n                                \"'Home' to begin training.\"\n                            )\n                        with gr.Row():\n                            self._cur_image = gr.Image(value=None, label=\"Image\", interactive=False, width=500)\n                    with gr.Column():\n                        self._cur_caption = gr.Textbox(label=\"Caption\", interactive=True, lines=25)\n\n                with gr.Row():\n                    self._save_and_prev_button = gr.Button(\"Save and Go-To Previous\")\n                    self._save_and_next_button = gr.Button(\"Save and Go-To Next\")\n\n                gr.Markdown(\"## Raw JSONL\")\n                self._data_jsonl = gr.Code(label=\"Dataset .jsonl\", language=\"json\", interactive=False)\n\n            self._edit_dataset_group = edit_dataset_group\n            self._app = app\n\n            standard_outputs = [\n                self._select_dataset_group,\n                self._edit_dataset_group,\n                self._current_jsonl_path,\n                self._cur_len_number,\n                
self._cur_example_index,\n                self._cur_image,\n                self._cur_caption,\n                self._beyond_dataset_limits_warning,\n                self._data_jsonl,\n            ]\n\n            self._load_existing_dataset_button.click(\n                self._on_load_existing_dataset_button_click,\n                inputs=set([self._existing_jsonl_path, self._image_column_textbox, self._caption_column_textbox]),\n                outputs=standard_outputs,\n            )\n\n            self._create_new_dataset_button.click(\n                self._on_create_dataset_button_click,\n                inputs=set([self._new_jsonl_path]),\n                outputs=standard_outputs,\n            )\n\n            self._change_dataset_button.click(\n                self._on_change_dataset_button_click, inputs=None, outputs=standard_outputs\n            )\n            self._save_and_prev_button.click(\n                self._on_save_and_prev_button_click,\n                inputs=set([self._cur_example_index, self._cur_caption]),\n                outputs=standard_outputs,\n            )\n\n            self._save_and_next_button.click(\n                self._on_save_and_next_button_click,\n                inputs=set([self._cur_example_index, self._cur_caption]),\n                outputs=standard_outputs,\n            )\n\n            self._add_images_button.click(\n                self._on_add_images_button_click,\n                inputs=set([self._image_source_textbox]),\n                outputs=standard_outputs,\n            )\n\n            self._cur_example_index.input(\n                self._on_cur_example_index_change,\n                inputs=set([self._cur_example_index]),\n                outputs=standard_outputs,\n            )\n\n    def _update_state(self, idx: int):\n        if self._dataset is None or self._jsonl_path is None:\n            return {\n                self._select_dataset_group: gr.Group(visible=True),\n                
self._edit_dataset_group: gr.Column(visible=False),\n                self._current_jsonl_path: None,\n                self._cur_len_number: 0,\n                self._cur_example_index: 0,\n                self._cur_image: None,\n                self._cur_caption: None,\n                self._beyond_dataset_limits_warning: gr.Markdown(visible=False),\n                self._data_jsonl: \"\",\n            }\n\n        idx = idx\n        image = None\n        caption = None\n        beyond_limits = True\n        if 0 <= idx and idx < len(self._dataset):\n            beyond_limits = False\n            example = self._dataset[idx]\n            image: Image.Image = example[\"image\"]\n            caption = example[\"caption\"]\n\n            # Resize the image to have a max dimension of 1024. On slow connections, sending the full-size image can be\n            # very slow.\n            max_dim = 1024\n            if image.width > max_dim or image.height > max_dim:\n                scale = max_dim / max(image.width, image.height)\n                image = image.resize((int(image.width * scale), int(image.height * scale)))\n\n        jsonl_str = \"\\n\".join([example.model_dump_json() for example in self._dataset.examples])\n        return {\n            self._select_dataset_group: gr.Group(visible=self._dataset is None),\n            self._edit_dataset_group: gr.Column(visible=self._dataset is not None),\n            self._current_jsonl_path: str(self._jsonl_path),\n            self._cur_len_number: len(self._dataset),\n            self._cur_example_index: idx,\n            self._cur_image: image,\n            self._cur_caption: caption,\n            self._beyond_dataset_limits_warning: gr.Markdown(visible=beyond_limits),\n            self._data_jsonl: jsonl_str,\n        }\n\n    def _on_load_existing_dataset_button_click(self, data: dict):\n        \"\"\"Load an existing dataset.\"\"\"\n        jsonl_path = Path(data[self._existing_jsonl_path])\n        jsonl_path = 
jsonl_path.resolve()\n        if not jsonl_path.exists():\n            raise ValueError(f\"'{jsonl_path}' does not exist.\")\n\n        self._jsonl_path = jsonl_path\n        self._dataset = ImageCaptionJsonlDataset(\n            jsonl_path=jsonl_path,\n            image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT,\n            caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT,\n        )\n        return self._update_state(0)\n\n    def _on_create_dataset_button_click(self, data: dict):\n        \"\"\"Create a new dataset.\"\"\"\n        jsonl_path = Path(data[self._new_jsonl_path])\n        jsonl_path = jsonl_path.resolve()\n        if jsonl_path.exists():\n            raise ValueError(f\"'{jsonl_path}' already exists.\")\n\n        if jsonl_path.suffix != \".jsonl\":\n            raise ValueError(\"Invalid file extension. Expected '.jsonl'.\")\n\n        print(f\"Creating new dataset at '{jsonl_path}'.\")\n        jsonl_path.parent.mkdir(parents=True, exist_ok=True)\n        # Create an empty jsonl file.\n        save_jsonl([], jsonl_path)\n\n        self._jsonl_path = jsonl_path\n        self._dataset = ImageCaptionJsonlDataset(jsonl_path=jsonl_path)\n\n        return self._update_state(0)\n\n    def _on_change_dataset_button_click(self):\n        self._jsonl_path = None\n        self._dataset = None\n        return self._update_state(0)\n\n    def _on_save_and_go_button_click(self, data: dict, idx_change: int):\n        # Update the current caption and re-save the jsonl file.\n        idx: int = data[self._cur_example_index]\n        if idx < 0 or idx >= len(self._dataset):\n            # idx is out of bounds, so don't update the caption, but still change the index.\n            return self._update_state(idx + idx_change)\n\n        print(f\"Updating caption for example {idx} of '{self._jsonl_path}'.\")\n        caption = data[self._cur_caption]\n        self._dataset.examples[idx].caption = caption\n        
self._dataset.save_jsonl()\n\n        return self._update_state(idx + idx_change)\n\n    def _on_save_and_next_button_click(self, data: dict):\n        return self._on_save_and_go_button_click(data, 1)\n\n    def _on_save_and_prev_button_click(self, data: dict):\n        return self._on_save_and_go_button_click(data, -1)\n\n    def _on_cur_example_index_change(self, data: dict):\n        return self._update_state(data[self._cur_example_index])\n\n    def _on_add_images_button_click(self, data: dict):\n        \"\"\"Add images to the dataset.\"\"\"\n        image_source_path = Path(data[self._image_source_textbox])\n\n        if not image_source_path.exists():\n            raise ValueError(f\"'{image_source_path}' does not exist.\")\n\n        # Determine the list of image paths to add to the dataset.\n        image_paths = []\n        if image_source_path.is_file():\n            if image_source_path.suffix.lower() not in IMAGE_EXTENSIONS:\n                raise ValueError(\n                    f\"'{image_source_path}' is not a valid image file. 
Expected one of {IMAGE_EXTENSIONS}.\"\n                )\n\n            image_paths.append(image_source_path.resolve())\n        else:\n            # Recursively search for image files in the image_source_path directory.\n            for file_path in image_source_path.glob(\"**/*\"):\n                if file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS:\n                    image_paths.append(file_path.resolve())\n\n        # Avoid adding duplicate images.\n        cur_image_paths = set([Path(example.image_path) for example in self._dataset.examples])\n        image_paths = set(image_paths)\n        new_image_paths = image_paths - cur_image_paths\n        if len(new_image_paths) < len(image_paths):\n            print(f\"Skipping {len(image_paths) - len(new_image_paths)} images that are already in the dataset.\")\n\n        # Add the new images to the dataset.\n        print(f\"Adding {len(new_image_paths)} images to '{self._jsonl_path}'.\")\n        for image_path in new_image_paths:\n            self._dataset.examples.append(ImageCaptionExample(image_path=str(image_path), caption=\"\"))\n\n        # Save the updated dataset.\n        self._dataset.save_jsonl()\n\n        return self._update_state(0)\n\n    def app(self):\n        return self._app\n"
  },
  {
    "path": "src/invoke_training/ui/pages/training_page.py",
    "content": "import os\nimport subprocess\nimport tempfile\nimport time\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (\n    SdxlLoraAndTextualInversionConfig,\n)\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig\nfrom invoke_training.ui.config_groups.flux_lora_config_group import FluxLoraConfigGroup\nfrom invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup\nfrom invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup\nfrom invoke_training.ui.config_groups.sdxl_finetune_config_group import SdxlFinetuneConfigGroup\nfrom invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import (\n    SdxlLoraAndTextualInversionConfigGroup,\n)\nfrom invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup\nfrom invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup\nfrom invoke_training.ui.gradio_blocks.header import Header\nfrom invoke_training.ui.gradio_blocks.pipeline_tab import PipelineTab\nfrom invoke_training.ui.utils.utils import get_config_dir_path\n\n\nclass TrainingPage:\n    def __init__(self):\n        self._config_temp_directory = tempfile.TemporaryDirectory()\n        self._training_process = None\n\n        # Define the theme with dark mode as 
default\n        theme = gr.themes.Default()\n        theme._dark_mode = True\n\n        # Custom CSS\n        custom_css = \"\"\"\n        .dark {\n            /* Override the default accent color for dark mode */\n            --color-accent: #e6fd13 !important;\n            --color-accent-soft: #e6fd1333 !important; /* Optional: Adjust soft accent too */\n        }\n\n        .dark .tabs button[aria-selected=\"true\"] {\n            /* Keep selected tab text color override */\n            color: #e6fd13 !important;\n        }\n\n        /* Style checkbox checkmark in dark mode when checked */\n        .dark input[type=\"checkbox\"]:checked + span svg path {\n             /* Target the SVG path inside the checked checkbox */\n            stroke: black !important; /* Set the checkmark color to black */\n        }\n        \"\"\"\n\n        # Pass the theme and css to gr.Blocks\n        with gr.Blocks(\n            theme=theme,\n            css=custom_css,\n            title=\"invoke-training\",\n            analytics_enabled=False,\n            head=\"\"\"\n                <link rel=\"icon\" type=\"image/x-icon\" href=\"/assets/favicon.png\">\n                <script>\n                    window.addEventListener('beforeunload', function(e) {\n                        if (window.gradio_client) {\n                            try {\n                                window.gradio_client.cancel_all();\n                            } catch (err) {\n                                console.error('Error cancelling requests:', err);\n                            }\n                        }\n                    });\n                </script>\n            \"\"\",\n        ) as app:\n            self._header = Header()\n            with gr.Tab(label=\"SD LoRA\"):\n                PipelineTab(\n                    name=\"SD LoRA\",\n                    default_config_file_path=str(get_config_dir_path() / \"sd_lora_baroque_1x8gb.yaml\"),\n                    
pipeline_config_cls=SdLoraConfig,\n                    config_group_cls=SdLoraConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"SDXL LoRA\"):\n                PipelineTab(\n                    name=\"SDXL LoRA\",\n                    default_config_file_path=str(get_config_dir_path() / \"sdxl_lora_baroque_1x24gb.yaml\"),\n                    pipeline_config_cls=SdxlLoraConfig,\n                    config_group_cls=SdxlLoraConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"SD Textual Inversion\"):\n                PipelineTab(\n                    name=\"SD Textual Inversion\",\n                    default_config_file_path=str(get_config_dir_path() / \"sd_textual_inversion_gnome_1x8gb.yaml\"),\n                    pipeline_config_cls=SdTextualInversionConfig,\n                    config_group_cls=SdTextualInversionConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"SDXL Textual Inversion\"):\n                PipelineTab(\n                    name=\"SDXL Textual Inversion\",\n                    default_config_file_path=str(get_config_dir_path() / \"sdxl_textual_inversion_gnome_1x24gb.yaml\"),\n                    pipeline_config_cls=SdxlTextualInversionConfig,\n                    config_group_cls=SdxlTextualInversionConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"SDXL LoRA and Textual Inversion\"):\n                PipelineTab(\n                    name=\"SDXL LoRA and Textual Inversion\",\n                    default_config_file_path=str(get_config_dir_path() / \"sdxl_lora_and_ti_gnome_1x24gb.yaml\"),\n                    
pipeline_config_cls=SdxlLoraAndTextualInversionConfig,\n                    config_group_cls=SdxlLoraAndTextualInversionConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"SDXL Finetune\"):\n                PipelineTab(\n                    name=\"SDXL Finetune\",\n                    default_config_file_path=str(get_config_dir_path() / \"sdxl_finetune_baroque_1x24gb.yaml\"),\n                    pipeline_config_cls=SdxlFinetuneConfig,\n                    config_group_cls=SdxlFinetuneConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n            with gr.Tab(label=\"Flux LoRA\"):\n                PipelineTab(\n                    name=\"Flux LoRA\",\n                    default_config_file_path=str(\n                        get_config_dir_path() / \"flux_lora_1x40gb.yaml\"\n                    ),  # Changed from 8gb to 40gb # noqa: E501\n                    pipeline_config_cls=FluxLoraConfig,\n                    config_group_cls=FluxLoraConfigGroup,\n                    run_training_cb=self._run_training,\n                    app=app,\n                )\n\n        self._app = app\n\n    def app(self):\n        return self._app\n\n    def _run_training(self, config: PipelineConfig):\n        # Check if there is already a training process running.\n        if self._training_process is not None:\n            if self._training_process.poll() is None:\n                print(\n                    \"Tried to start a new training process, but another training process is already running. 
\"\n                    \"Terminate the existing process first.\"\n                )\n                return\n            else:\n                self._training_process = None\n\n        print(f\"Starting {config.type} training...\")\n\n        # Write the config to a temporary config file where the training subprocess can read it.\n        timestamp = str(time.time()).replace(\".\", \"_\")\n        config_path = os.path.join(self._config_temp_directory.name, f\"{timestamp}.yaml\")\n        with open(config_path, \"w\") as f:\n            yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False)\n\n        self._training_process = subprocess.Popen([\"invoke-train\", \"-c\", str(config_path)])\n\n        print(f\"Started {config.type} training.\")\n"
  },
  {
    "path": "src/invoke_training/ui/utils/prompts.py",
    "content": "NEGATIVE_PROMPT_DELIMITER = \"[NEG]\"\n\n\ndef split_pos_neg_prompts(prompt: str) -> tuple[str, str]:\n    \"\"\"Split a prompt containing a '[NEG]' delimiter into a positive prompt and a negative prompt.\n\n    Examples:\n    - 'positive prompt[NEG]negative prompt'     -> ('positive prompt', 'negative prompt')\n    - 'positive prompt'                         -> ('positive prompt', '')\n    - 'positive prompt[NEG]negative[NEG]prompt' -> Raises ValueError\n    \"\"\"\n    prompt = prompt.strip()\n\n    splits = prompt.split(NEGATIVE_PROMPT_DELIMITER)\n    if len(splits) == 1:\n        # This is a positive prompt only.\n        return splits[0], \"\"\n    elif len(splits) == 2:\n        # This is a positive prompt followed by a negative prompt.\n        return splits[0], splits[1]\n\n    raise ValueError(\n        f\"Failed to split the prompt into a positive and negative prompt. Expected at most one \"\n        f\"'{NEGATIVE_PROMPT_DELIMITER}' delimiter. Prompt: '{prompt}'.\"\n    )\n\n\ndef merge_pos_neg_prompts(positive_prompt: str, negative_prompt: str) -> str:\n    \"\"\"Merge a positive prompt and a negative prompt into a single prompt of the form:\n    'positive prompt[NEG]negative prompt'\n    \"\"\"\n    if NEGATIVE_PROMPT_DELIMITER in positive_prompt:\n        raise ValueError(\n            f\"Positive prompt cannot contain the '{NEGATIVE_PROMPT_DELIMITER}' delimiter. Prompt: '{positive_prompt}'\"\n        )\n    if NEGATIVE_PROMPT_DELIMITER in negative_prompt:\n        raise ValueError(\n            f\"Negative prompt cannot contain the '{NEGATIVE_PROMPT_DELIMITER}' delimiter. 
Prompt: '{negative_prompt}'\"\n        )\n\n    if negative_prompt == \"\":\n        return positive_prompt\n\n    return f\"{positive_prompt}{NEGATIVE_PROMPT_DELIMITER}{negative_prompt}\"\n\n\ndef convert_ui_prompts_to_pos_neg_prompts(prompts: str) -> tuple[list[str], list[str] | None]:\n    \"\"\"Convert prompts from the UI textbox format to lists of positive and negative prompts.\"\"\"\n\n    ui_prompt_list = prompts.split(\"\\n\")\n    positive_prompts = []\n    negative_prompts = []\n    for prompt in ui_prompt_list:\n        positive_prompt, negative_prompt = split_pos_neg_prompts(prompt)\n\n        # Skip empty lines.\n        if positive_prompt == \"\" and negative_prompt == \"\":\n            continue\n\n        positive_prompts.append(positive_prompt)\n        negative_prompts.append(negative_prompt)\n\n    # Convert negative_prompts to None if all negative prompts are empty.\n    if all([p == \"\" for p in negative_prompts]):\n        negative_prompts = None\n    return positive_prompts, negative_prompts\n\n\ndef convert_pos_neg_prompts_to_ui_prompts(positive_prompts: list[str], negative_prompts: list[str] | None) -> str:\n    \"\"\"Convert lists of positive and negative prompts to the UI textbox format.\"\"\"\n    if negative_prompts is None:\n        negative_prompts = [\"\"] * len(positive_prompts)\n\n    ui_prompts = \"\"\n    for positive_prompt, negative_prompt in zip(positive_prompts, negative_prompts, strict=True):\n        ui_prompts += merge_pos_neg_prompts(positive_prompt, negative_prompt) + \"\\n\"\n    return ui_prompts.strip()\n"
  },
  {
    "path": "src/invoke_training/ui/utils/utils.py",
    "content": "import typing\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\n\n\ndef get_config_dir_path() -> Path:\n    p = Path(__file__).parent.parent.parent / \"sample_configs\"\n    if not p.exists():\n        raise FileNotFoundError(f\"Config directory not found: '{p}'\")\n    return p\n\n\ndef get_assets_dir_path() -> Path:\n    p = Path(__file__).parent.parent.parent / \"assets\"\n    if not p.exists():\n        pass\n    return p\n\n\ndef load_config_from_yaml(file_path: Path | str) -> PipelineConfig:\n    file_path = Path(file_path)\n    with open(file_path, \"r\") as f:\n        cfg = yaml.safe_load(f)\n\n    pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)\n    train_config = pipeline_adapter.validate_python(cfg)\n\n    return train_config\n\n\ndef get_typing_literal_options(cls, field_name: str) -> list[str]:\n    literal_type_hint = typing.get_type_hints(cls)[field_name]\n    return list(typing.get_args(literal_type_hint))\n"
  },
  {
    "path": "tests/invoke_training/_shared/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py",
    "content": "import os\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker\n\n\ndef test_checkpoint_tracker_get_path_file():\n    \"\"\"Test the CheckpointTracker.get_path(...) method with an extension.\"\"\"\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=\"base_dir\",\n        prefix=\"prefix\",\n        extension=\".ckpt\",\n        index_padding=8,\n    )\n\n    path = checkpoint_tracker.get_path(epoch=1, step=55)\n\n    assert Path(path) == Path(\"base_dir/prefix-epoch_00000001-step_00000055.ckpt\")\n\n\ndef test_checkpoint_tracker_get_path_directory():\n    \"\"\"Test the CheckpointTracker.get_path(...) method without an extension.\"\"\"\n    checkpoint_tracker = CheckpointTracker(\n        base_dir=\"base_dir\",\n        prefix=\"prefix\",\n        extension=None,\n        index_padding=8,\n    )\n\n    path = checkpoint_tracker.get_path(epoch=1, step=55)\n\n    assert Path(path) == Path(\"base_dir/prefix-epoch_00000001-step_00000055\")\n\n\ndef test_checkpoint_tracker_bad_extension():\n    \"\"\"Test that CheckpointTracker raises a ValueError if an attempt is made to initialize it with an invalid\n    extension.\n    \"\"\"\n    with pytest.raises(ValueError):\n        _ = CheckpointTracker(base_dir=\"base_dir\", prefix=\"prefix\", extension=\"ckpt\")\n\n\ndef test_checkpoint_tracker_prune_files():\n    \"\"\"Test the CheckpointTracker.prune() method with checkpoint files.\"\"\"\n    with tempfile.TemporaryDirectory() as dir_name:\n        checkpoint_tracker = CheckpointTracker(base_dir=dir_name, prefix=\"prefix\", extension=\".ckpt\", max_checkpoints=5)\n        # Create 6 checkpoints.\n        for i in range(6):\n            path = checkpoint_tracker.get_path(epoch=0, step=i)\n            with open(path, \"w\") as f:\n                f.write(\"hi\")\n\n        # Prune the 3 checkpoints with the lowest step counts.\n        num_pruned = 
checkpoint_tracker.prune(2)\n        assert num_pruned == 3\n\n        # Verify that the correct checkpoints were pruned.\n        assert all([not os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3)])\n        assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3, 6)])\n\n\ndef test_checkpoint_tracker_prune_directories():\n    \"\"\"Test the CheckpointTracker.prune() method with checkpoint directories.\"\"\"\n    with tempfile.TemporaryDirectory() as dir_name:\n        checkpoint_tracker = CheckpointTracker(base_dir=dir_name, prefix=\"prefix\", extension=None, max_checkpoints=5)\n        # Create 6 checkpoints.\n        for i in range(6):\n            path = checkpoint_tracker.get_path(epoch=0, step=i)\n            # Create checkpoint directory and add file to it.\n            os.makedirs(path)\n            with open(os.path.join(path, \"tmp.txt\"), \"w\") as f:\n                f.write(\"hi\")\n\n        # Prune the 3 checkpoints with lowest indices.\n        num_pruned = checkpoint_tracker.prune(2)\n        assert num_pruned == 3\n\n        # Verify that the correct checkpoints were pruned.\n        assert all([not os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3)])\n        assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3, 6)])\n\n\ndef test_checkpoint_tracker_prune_no_max():\n    \"\"\"Test that CheckpointTracker.prune() is a no-op when max_checkpoints is None.\"\"\"\n    with tempfile.TemporaryDirectory() as dir_name:\n        checkpoint_tracker = CheckpointTracker(\n            base_dir=dir_name, prefix=\"prefix\", extension=\".ckpt\", max_checkpoints=None\n        )\n        # Create 6 checkpoints.\n        for i in range(6):\n            path = checkpoint_tracker.get_path(epoch=0, step=i)\n            with open(path, \"w\") as f:\n                f.write(\"hi\")\n\n        # Call prune, which should have no effect.\n        
num_pruned = checkpoint_tracker.prune(2)\n        assert num_pruned == 0\n\n        # Verify that no checkpoints were deleted.\n        assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(6)])\n"
  },
  {
    "path": "tests/invoke_training/_shared/checkpoints/test_serialization.py",
    "content": "import os\nimport tempfile\n\nimport pytest\nimport torch\n\nfrom invoke_training._shared.checkpoints.serialization import (\n    load_state_dict,\n    save_state_dict,\n)\n\n\n@pytest.mark.parametrize(\"file_name\", [\"state.ckpt\", \"state.pt\", \"state.safetensors\"])\ndef test_state_dict_save_and_load_roundtrip(file_name):\n    with tempfile.TemporaryDirectory() as dir_name:\n        file_path = os.path.join(dir_name, file_name)\n\n        in_state_dict = {\"a\": torch.Tensor([1.0, 2.0])}\n\n        # Perform save-load roundtrip.\n        save_state_dict(in_state_dict, file_path)\n        out_state_dict = load_state_dict(file_path)\n\n    assert len(in_state_dict) == len(out_state_dict)\n    for key in in_state_dict:\n        assert torch.equal(in_state_dict[key], out_state_dict[key])\n\n\ndef test_save_state_dict_bad_extension():\n    \"\"\"Test that save_state_dict(...) raises a ValueError if it receives an unsupported file extension.\"\"\"\n    with pytest.raises(ValueError):\n        save_state_dict({}, \"state.txt\")\n\n\ndef test_load_state_dict_bad_extension():\n    \"\"\"Test that load_state_dict(...) raises a ValueError if it receives an unsupported file extension.\"\"\"\n    with pytest.raises(ValueError):\n        load_state_dict(\"state.txt\")\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/test_dreambooth_sd_dataloader.py",
    "content": "import torch\n\nfrom invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import (\n    build_dreambooth_sd_dataloader,\n)\nfrom invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, DreamboothSDDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import ImageDirDatasetConfig\n\nfrom ..dataset_fixtures import image_dir  # noqa: F401\n\n\ndef test_build_dreambooth_sd_dataloader(image_dir):  # noqa: F811\n    \"\"\"Smoke test of build_dreambooth_sd_dataloader(...).\"\"\"\n    config = DreamboothSDDataLoaderConfig(\n        instance_caption=\"test instance prompt\",\n        instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)),\n        class_caption=\"test class prompt\",\n        # For testing, we just use the same directory for the instance and class datasets.\n        class_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)),\n    )\n    data_loader = build_dreambooth_sd_dataloader(config=config, batch_size=2)\n\n    assert len(data_loader) == 5  # (5 class images + 5 instance images) / batch size 2\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\", \"loss_weight\"}\n\n    image = example[\"image\"]\n    assert image.shape == (2, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 2\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 2\n    assert len(crop_top_left_yx[0]) == 2\n\n    caption = example[\"caption\"]\n    assert caption == [\"test instance prompt\", \"test class prompt\"]\n\n    loss_weight = example[\"loss_weight\"]\n    assert loss_weight.shape == (2,)\n    assert loss_weight.dtype == torch.float32\n\n\ndef test_build_dreambooth_sd_dataloader_no_class_dataset(image_dir):  # noqa: F811\n  
  \"\"\"Smoke test of build_dreambooth_sd_dataloader(...) without a class dataset.\"\"\"\n\n    config = DreamboothSDDataLoaderConfig(\n        instance_caption=\"test instance prompt\",\n        instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)),\n    )\n    data_loader = build_dreambooth_sd_dataloader(config=config, batch_size=2)\n\n    assert len(data_loader) == 3  # 5 instance images, batch size 2\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\", \"loss_weight\"}\n\n    image = example[\"image\"]\n    assert image.shape == (2, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 2\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 2\n    assert len(crop_top_left_yx[0]) == 2\n\n    caption = example[\"caption\"]\n    assert caption == [\"test instance prompt\", \"test instance prompt\"]\n\n    loss_weight = example[\"loss_weight\"]\n    assert loss_weight.shape == (2,)\n    assert loss_weight.dtype == torch.float32\n\n\ndef test_build_dreambooth_sd_dataloader_with_bucketing(image_dir):  # noqa: F811\n    \"\"\"Smoke test of build_dreambooth_sd_dataloader(...).\"\"\"\n    config = DreamboothSDDataLoaderConfig(\n        instance_caption=\"test instance prompt\",\n        instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)),\n        class_caption=\"test class prompt\",\n        # For testing, we just use the same directory for the instance and class datasets.\n        class_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)),\n        aspect_ratio_buckets=AspectRatioBucketConfig(\n            target_resolution=256, start_dim=128, end_dim=512, divisible_by=64\n        ),\n    )\n    data_loader = build_dreambooth_sd_dataloader(config=config, 
batch_size=2, shuffle=False, sequential_batching=True)\n\n    assert len(data_loader) == 6  # 5 class images -> 3 batches + 5 instance images -> 3 batches\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\", \"loss_weight\"}\n\n    image = example[\"image\"]\n    assert image.shape == (2, 3, 256, 256)\n    assert image.dtype == torch.float32\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 2\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 2\n    assert len(crop_top_left_yx[0]) == 2\n\n    caption = example[\"caption\"]\n    assert caption == [\"test instance prompt\", \"test instance prompt\"]\n\n    loss_weight = example[\"loss_weight\"]\n    assert loss_weight.shape == (2,)\n    assert loss_weight.dtype == torch.float32\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/test_image_caption_sd_dataloader.py",
    "content": "import math\n\nimport torch\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader\nfrom invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig\n\nfrom ..dataset_fixtures import image_caption_jsonl  # noqa: F401\n\n\ndef test_build_image_caption_sd_dataloader(image_caption_jsonl):  # noqa: F811\n    \"\"\"Smoke test of build_image_caption_sd_dataloader(...).\"\"\"\n\n    config = ImageCaptionSDDataLoaderConfig(\n        dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)),\n    )\n    data_loader = build_image_caption_sd_dataloader(config, 4)\n\n    # The dataset has length 5, so the data loader should have 2 batches.\n    assert len(data_loader) == math.ceil(5 / 4)\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\"}\n\n    image = example[\"image\"]\n    assert image.shape == (4, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    assert len(example[\"caption\"]) == 4\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 4\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 4\n    assert len(crop_top_left_yx[0]) == 2\n\n\ndef test_build_image_caption_sd_dataloader_with_masks(image_caption_jsonl):  # noqa: F811\n    \"\"\"Smoke test of build_image_caption_sd_dataloader(...).\"\"\"\n\n    config = ImageCaptionSDDataLoaderConfig(\n        dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)),\n    )\n    data_loader = build_image_caption_sd_dataloader(config, 4, use_masks=True)\n\n    # The dataset has length 5, so the data loader should have 2 batches.\n    assert len(data_loader) == math.ceil(5 / 
4)\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"mask\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\"}\n\n    image = example[\"image\"]\n    assert image.shape == (4, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    mask = example[\"mask\"]\n    assert mask.shape == (4, 1, 512, 512)\n    assert mask.dtype == torch.float32\n\n    assert len(example[\"caption\"]) == 4\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 4\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 4\n    assert len(crop_top_left_yx[0]) == 2\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/test_image_pair_preference_sd_dataloader.py",
    "content": "import pytest\nimport torch\n\nfrom invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import (\n    build_image_pair_preference_sd_dataloader,\n)\nfrom invoke_training.pipelines._experimental.sd_dpo_lora.config import (\n    HFHubImagePairPreferenceDatasetConfig,\n    ImagePairPreferenceSDDataLoaderConfig,\n)\n\n\n@pytest.mark.skip(\n    reason=\"No yuvalkirstain/pickapic_v2 dataset on HF Hub: https://huggingface.co/datasets/yuvalkirstain/pickapic_v2\"\n)\ndef test_build_image_pair_preference_sd_dataloader():\n    \"\"\"Smoke test of build_image_pair_preference_sd_dataloader(...).\"\"\"\n\n    config = ImagePairPreferenceSDDataLoaderConfig(dataset=HFHubImagePairPreferenceDatasetConfig())\n    data_loader = build_image_pair_preference_sd_dataloader(config, 4)\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\n        \"id\",\n        \"image_0\",\n        \"original_size_hw_0\",\n        \"crop_top_left_yx_0\",\n        \"prefer_0\",\n        \"image_1\",\n        \"original_size_hw_1\",\n        \"crop_top_left_yx_1\",\n        \"prefer_1\",\n        \"caption\",\n    }\n\n    for image_key in [\"image_0\", \"image_1\"]:\n        image = example[image_key]\n        assert image.shape == (4, 3, 512, 512)\n        assert image.dtype == torch.float32\n\n    assert len(example[\"caption\"]) == 4\n\n    for orig_size_key in [\"original_size_hw_0\", \"original_size_hw_1\"]:\n        original_size_hw = example[orig_size_key]\n        assert len(original_size_hw) == 4\n        assert len(original_size_hw[0]) == 2\n\n    for crop_key in [\"crop_top_left_yx_0\", \"crop_top_left_yx_1\"]:\n        crop_top_left_yx = example[crop_key]\n        assert len(crop_top_left_yx) == 4\n        assert len(crop_top_left_yx[0]) == 2\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/test_textual_inversion_sd_dataloader.py",
    "content": "import torch\n\nfrom invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (\n    build_textual_inversion_sd_dataloader,\n)\nfrom invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig\nfrom invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig, ImageDirDatasetConfig\n\nfrom ..dataset_fixtures import (\n    image_caption_jsonl,  # noqa: F401\n    image_dir,  # noqa: F401\n)\n\n\ndef test_build_textual_inversion_sd_dataloader(image_dir):  # noqa: F811\n    \"\"\"Smoke test of build_textual_inversion_sd_dataloader(...).\"\"\"\n\n    config = TextualInversionSDDataLoaderConfig(\n        dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), caption_preset=\"object\"\n    )\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config,\n        placeholder_token=\"placeholder\",\n        batch_size=2,\n    )\n\n    assert len(data_loader) == 3  # ceil(5 images / batch size 2)\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\"}\n\n    image = example[\"image\"]\n    assert image.shape == (2, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    assert len(example[\"caption\"]) == 2\n    for caption in example[\"caption\"]:\n        assert \"placeholder\" in caption\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 2\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 2\n    assert len(crop_top_left_yx[0]) == 2\n\n\ndef test_build_textual_inversion_sd_dataloader_keep_original_captions(image_caption_jsonl):  # noqa: F811\n    \"\"\"Test the keep_original_captions=True option.\"\"\"\n    config = TextualInversionSDDataLoaderConfig(\n        
dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)),\n        caption_templates=[\"{}\"],\n        keep_original_captions=True,\n    )\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config,\n        placeholder_token=\"placeholder\",\n        batch_size=2,\n    )\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\"}\n\n    assert len(example[\"caption\"]) == 2\n    for caption in example[\"caption\"]:\n        assert caption.startswith(\"placeholder \")\n\n\ndef test_build_textual_inversion_sd_dataloader_with_masks(image_caption_jsonl):  # noqa: F811\n    \"\"\"Test the use_masks=True option.\"\"\"\n    config = TextualInversionSDDataLoaderConfig(\n        dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)),\n        caption_templates=[\"{}\"],\n    )\n\n    data_loader = build_textual_inversion_sd_dataloader(\n        config=config,\n        placeholder_token=\"placeholder\",\n        batch_size=2,\n        use_masks=True,\n    )\n\n    example = next(iter(data_loader))\n    assert set(example.keys()) == {\"image\", \"mask\", \"id\", \"caption\", \"original_size_hw\", \"crop_top_left_yx\"}\n\n    image = example[\"image\"]\n    assert image.shape == (2, 3, 512, 512)\n    assert image.dtype == torch.float32\n\n    mask = example[\"mask\"]\n    assert mask.shape == (2, 1, 512, 512)\n    assert mask.dtype == torch.float32\n\n    assert len(example[\"caption\"]) == 2\n    for caption in example[\"caption\"]:\n        assert \"placeholder\" in caption\n\n    original_size_hw = example[\"original_size_hw\"]\n    assert len(original_size_hw) == 2\n    assert len(original_size_hw[0]) == 2\n\n    crop_top_left_yx = example[\"crop_top_left_yx\"]\n    assert len(crop_top_left_yx) == 2\n    assert len(crop_top_left_yx[0]) == 2\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/dataset_fixtures.py",
    "content": "import numpy as np\nimport PIL.Image\nimport pytest\n\nfrom invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset\nfrom invoke_training._shared.utils.jsonl import save_jsonl\n\n\n@pytest.fixture(scope=\"session\")\ndef image_dir(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that populates a temp directory with some test images and returns the directory path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session, because it is\n    costly to populate the directory.\n\n    Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use\n    of tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"dataset\")\n\n    for i in range(5):\n        rgb_np = np.ones((128, 128, 3), dtype=np.uint8)\n        rgb_pil = PIL.Image.fromarray(rgb_np)\n        rgb_pil.save(tmp_dir / f\"{i}.jpg\")\n\n    return tmp_dir\n\n\n@pytest.fixture(scope=\"session\")\ndef image_caption_dir(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that populates a temp directory with some test images and caption files and returns the directory path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session, because it is\n    costly to populate the directory.\n\n    Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use\n    of tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"dataset\")\n\n    for i in range(5):\n        rgb_np = np.ones((128, 128, 3), dtype=np.uint8)\n        rgb_pil = PIL.Image.fromarray(rgb_np)\n        rgb_pil.save(tmp_dir / f\"{i}.jpg\")\n\n        with open(tmp_dir / f\"{i}.txt\", \"w\") as f:\n            f.write(f\"caption {i}\")\n\n    return tmp_dir\n\n\n@pytest.fixture(scope=\"session\")\ndef image_caption_jsonl(tmp_path_factory: 
pytest.TempPathFactory):\n    \"\"\"A fixture that populates a temp directory with a ImageCaptionJsonlDataset and returns the jsonl file path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session, because it is\n    costly to populate the directory.\n\n    Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use\n    of tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"dataset\")\n\n    masks_dir = tmp_dir / \"masks\"\n    masks_dir.mkdir()\n\n    data = []\n\n    for i in range(5):\n        rgb_np = np.ones((128, 128, 3), dtype=np.uint8)\n        rgb_pil = PIL.Image.fromarray(rgb_np)\n        rgb_rel_path = f\"{i}.jpg\"\n        rgb_pil.save(tmp_dir / rgb_rel_path)\n\n        mask_np = np.ones((128, 128), dtype=np.uint8)\n        mask_pil = PIL.Image.fromarray(mask_np).convert(\"L\")\n        mask_rel_path = f\"masks/{i}.png\"\n        mask_pil.save(tmp_dir / mask_rel_path)\n\n        data.append({\"image\": str(rgb_rel_path), \"mask\": str(mask_rel_path), \"text\": f\"caption {i}\"})\n\n    data_jsonl_path = tmp_dir / \"data.jsonl\"\n    save_jsonl(data, data_jsonl_path)\n    return data_jsonl_path\n\n\n@pytest.fixture(scope=\"session\")\ndef image_pair_preference_dir(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that populates a temp directory with a mock dataset intended to be consumed by\n    ImagePairPreferenceDataset, and returns the directory path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session, because it is\n    costly to populate the directory.\n\n    Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use\n    of tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"dataset\")\n\n    prompts = [\"mock prompt 1\", \"mock prompt 2\"]\n    metadata = []\n\n    for prompt_idx in 
range(len(prompts)):\n        for set_idx in range(3):\n            set_dir = tmp_dir / f\"prompt-{prompt_idx:0>4}\" / f\"set-{set_idx:0>4}\"\n            set_dir.mkdir(parents=True)\n            set_metadata_dict = {\"prompt\": prompts[prompt_idx]}\n            for image_idx in range(2):\n                rgb_np = np.ones((32, 32, 3), dtype=np.uint8)\n                rgb_pil = PIL.Image.fromarray(rgb_np)\n                image_path = set_dir / f\"image-{image_idx}.jpg\"\n                rgb_pil.save(image_path)\n                set_metadata_dict[f\"image_{image_idx}\"] = str(image_path.relative_to(tmp_dir))\n                set_metadata_dict[f\"prefer_{image_idx}\"] = image_idx == 0  # Always prefer image 0.\n\n            metadata.append(set_metadata_dict)\n\n    ImagePairPreferenceDataset.save_metadata(metadata=metadata, dataset_dir=tmp_dir)\n\n    return tmp_dir\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py",
    "content": "from pathlib import Path\n\nimport numpy as np\nimport PIL\nimport pytest\nfrom PIL import Image\n\nfrom invoke_training._shared.data.datasets.hf_image_caption_dataset import (\n    HFImageCaptionDataset,\n)\nfrom invoke_training._shared.data.utils.resolution import Resolution\nfrom invoke_training._shared.utils.jsonl import save_jsonl\n\n################################################\n# Tests for HFImageCaptionDataset.from_dir(...)\n################################################\n\n\ndef create_hf_imagefolder_dataset(tmp_dir: Path, num_images: int):\n    \"\"\"Construct a mock Hugging Face imagefolder dataset in a temporary directory.\n\n    Args:\n        tmp_dir (Path): The temporary directory where the mock dataset will be created.\n        num_images (int): The number of mock images to include in the dataset.\n    \"\"\"\n    # Construct mock images and save them to disk.\n    rel_img_paths = []\n    for i in range(num_images):\n        rgb_np = np.ones((128, 128, 3), dtype=np.uint8)\n        rgb_pil = Image.fromarray(rgb_np)\n        rel_img_path = f\"{i}.jpg\"\n        rel_img_paths.append(rel_img_path)\n        rgb_pil.save(tmp_dir / rel_img_path)\n\n    # Construct a mock metadata dict.\n    metadata = []\n    for rel_img_path in rel_img_paths:\n        metadata.append({\"file_name\": rel_img_path, \"text\": f\"Caption for {rel_img_path}\"})\n\n    # Write the metadata.jsonl to disk.\n    metadata_path = tmp_dir / \"metadata.jsonl\"\n    save_jsonl(metadata, metadata_path)\n\n\n@pytest.fixture(scope=\"session\")\ndef hf_imagefolder_dir(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that prepares a temp directory with a mock Hugging Face imagefolder dataset and returns the directory\n    path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session, because it is\n    costly to populate the directory.\n\n    Refer to 
https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use\n    of tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"dataset\")\n    create_hf_imagefolder_dataset(tmp_dir, 5)\n    return tmp_dir\n\n\n@pytest.fixture()\ndef hf_dir_dataset(hf_imagefolder_dir: Path):\n    return HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir))\n\n\ndef test_hf_dir_image_caption_dataset_bad_image_column(hf_imagefolder_dir: Path):\n    \"\"\"Test that a ValueError is raised if HFImageCaptionDataset is initialized with an `image_column` that does not\n    exist.\n    \"\"\"\n    with pytest.raises(ValueError):\n        _ = HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir), image_column=\"does_not_exist\")\n\n\ndef test_hf_dir_image_caption_dataset_bad_caption_column(hf_imagefolder_dir: Path):\n    \"\"\"Test that a ValueError is raised if HFImageCaptionDataset is initialized with a `caption_column` that does not\n    exist.\n    \"\"\"\n    with pytest.raises(ValueError):\n        _ = HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir), caption_column=\"does_not_exist\")\n\n\ndef test_hf_dir_image_caption_dataset_len(hf_dir_dataset: HFImageCaptionDataset):\n    \"\"\"Test the behaviour of HFImageCaptionDataset.__len__().\"\"\"\n    assert len(hf_dir_dataset) == 5\n\n\ndef test_hf_dir_image_caption_dataset_index_error(hf_dir_dataset: HFImageCaptionDataset):\n    \"\"\"Test that an IndexError is raised if a dataset element is accessed with an index that is out-of-bounds.\"\"\"\n    with pytest.raises(IndexError):\n        _ = hf_dir_dataset[1000]\n\n\ndef test_hf_dir_image_caption_dataset_getitem(hf_dir_dataset: HFImageCaptionDataset):\n    \"\"\"Test that HFImageCaptionDataset.__getitem__(...) 
returns a valid example.\"\"\"\n    example = hf_dir_dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"caption\", \"id\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert isinstance(example[\"caption\"], str)\n    assert example[\"id\"] == 0\n\n\ndef test_hf_dir_image_caption_dataset_get_image_dimensions(hf_dir_dataset: HFImageCaptionDataset):\n    \"\"\"Test HFImageCaptionDataset.get_image_dimensions().\"\"\"\n\n    image_dims = hf_dir_dataset.get_image_dimensions()\n\n    assert len(image_dims) == 5\n    for image_dim in image_dims:\n        assert image_dim == Resolution(128, 128)\n\n\n################################################\n# Tests for HFImageCaptionDataset.from_hub(...)\n################################################\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_bad_image_column():\n    \"\"\"Test that a ValueError is raised if HFImageCaptionDataset is initialized with an `image_column` that does not\n    exist.\n    \"\"\"\n    with pytest.raises(ValueError):\n        _ = HFImageCaptionDataset.from_hub(\n            \"lambdalabs/pokemon-blip-captions\",\n            hf_load_dataset_kwargs={\"revision\": \"8b762e1dac1b31d60e01ee8f08a9d8a232b59e17\"},\n            image_column=\"does_not_exist\",\n        )\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_bad_caption_column():\n    \"\"\"Test that a ValueError is raised if HFImageCaptionDataset is initialized with a `caption_column` that does not\n    exist.\n    \"\"\"\n    with pytest.raises(ValueError):\n        _ = HFImageCaptionDataset.from_hub(\n            \"lambdalabs/pokemon-blip-captions\",\n            hf_load_dataset_kwargs={\"revision\": 
\"8b762e1dac1b31d60e01ee8f08a9d8a232b59e17\"},\n            caption_column=\"does_not_exist\",\n        )\n\n\n@pytest.fixture\ndef hf_hub_dataset():\n    return HFImageCaptionDataset.from_hub(\n        \"lambdalabs/pokemon-blip-captions\",\n        hf_load_dataset_kwargs={\"revision\": \"8b762e1dac1b31d60e01ee8f08a9d8a232b59e17\"},\n    )\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_index_error(hf_hub_dataset: HFImageCaptionDataset):\n    \"\"\"Test that an IndexError is raised if a dataset element is accessed with an index that is out-of-bounds.\"\"\"\n    with pytest.raises(IndexError):\n        _ = hf_hub_dataset[1000]\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_len(hf_hub_dataset: HFImageCaptionDataset):\n    \"\"\"Test the behaviour of HFImageCaptionDataset.__len__().\"\"\"\n    # Expected dataset length was checked manually here:\n    # https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions\n    assert len(hf_hub_dataset) == 833\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_getitem(hf_hub_dataset: HFImageCaptionDataset):\n    \"\"\"Test that HFImageCaptionDataset.__getitem__(...) 
returns a valid example.\"\"\"\n    example = hf_hub_dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"caption\", \"id\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert isinstance(example[\"caption\"], str)\n    assert example[\"id\"] == 0\n\n\n@pytest.mark.skip(reason=\"The lambdalabs/pokemon-blip-captions dataset is no longer available.\")\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_get_image_dimensions(hf_hub_dataset: HFImageCaptionDataset):\n    \"\"\"Test HFImageCaptionDataset.get_image_dimensions().\"\"\"\n    image_dims = hf_hub_dataset.get_image_dimensions()\n\n    # This is just a smoke test. We don't actually check that the dimensions are correct.\n    assert len(image_dims) == 833\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_hf_image_pair_preference_dataset.py",
    "content": "import pytest\nfrom datasets import VerificationMode\nfrom PIL.Image import Image\n\nfrom invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset\n\n\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_getitem():\n    \"\"\"Test that HFImagePairPreferenceDataset.__getitem__(...) returns a valid example.\"\"\"\n    # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large\n    # 'yuvalkirstain/pickapic_v2' dataset.\n    dataset = HFImagePairPreferenceDataset.from_hub(\n        \"yuvalkirstain/pickapic_v2\",\n        split=\"validation_unique\",\n        hf_load_dataset_kwargs={\n            \"data_files\": {\n                \"validation_unique\": \"data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet\",\n            },\n            # Disable checks so that it doesn't complain that I haven't downloaded the other splits.\n            \"verification_mode\": VerificationMode.NO_CHECKS,\n        },\n    )\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"id\", \"image_0\", \"image_1\", \"prefer_0\", \"prefer_1\", \"caption\"}\n\n    assert example[\"id\"] == 0\n\n    assert isinstance(example[\"image_0\"], Image)\n    assert example[\"image_0\"].mode == \"RGB\"\n    assert isinstance(example[\"image_1\"], Image)\n    assert example[\"image_1\"].mode == \"RGB\"\n\n    assert isinstance(example[\"prefer_0\"], bool)\n    assert isinstance(example[\"prefer_1\"], bool)\n    # The following is not always true, but is usually true.\n    assert example[\"prefer_0\"] != example[\"prefer_1\"]\n\n    assert isinstance(example[\"caption\"], str)\n\n\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_len():\n    \"\"\"Test that HFImagePairPreferenceDataset.__len__(...) 
returns the correct value.\"\"\"\n    # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large\n    # 'yuvalkirstain/pickapic_v2' dataset.\n    dataset = HFImagePairPreferenceDataset.from_hub(\n        \"yuvalkirstain/pickapic_v2\",\n        skip_no_preference=False,\n        split=\"validation_unique\",\n        hf_load_dataset_kwargs={\n            \"data_files\": {\n                \"validation_unique\": \"data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet\",\n            },\n            # Disable checks so that it doesn't complain that I haven't downloaded the other splits.\n            \"verification_mode\": VerificationMode.NO_CHECKS,\n        },\n    )\n\n    assert len(dataset) == 500\n\n\n@pytest.mark.loads_model\ndef test_hf_hub_image_caption_dataset_skip_no_preference_len():\n    \"\"\"Test the HFImagePairPreferenceDataset skip_no_preference parameter.\"\"\"\n    # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large\n    # 'yuvalkirstain/pickapic_v2' dataset.\n    dataset = HFImagePairPreferenceDataset.from_hub(\n        \"yuvalkirstain/pickapic_v2\",\n        skip_no_preference=True,\n        split=\"validation_unique\",\n        hf_load_dataset_kwargs={\n            \"data_files\": {\n                \"validation_unique\": \"data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet\",\n            },\n            # Disable checks so that it doesn't complain that I haven't downloaded the other splits.\n            \"verification_mode\": VerificationMode.NO_CHECKS,\n        },\n    )\n\n    assert len(dataset) == 429\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_image_caption_dir_dataset.py",
    "content": "from pathlib import Path\n\nimport PIL.Image\nimport pytest\n\nfrom invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset\n\nfrom ..dataset_fixtures import image_caption_dir  # noqa: F401\n\n\ndef test_image_caption_dir_dataset_len(image_caption_dir):  # noqa: F811\n    dataset = ImageCaptionDirDataset(str(image_caption_dir))\n\n    assert len(dataset) == 5\n\n\ndef test_image_caption_dir_dataset_getitem(image_caption_dir):  # noqa: F811\n    dataset = ImageCaptionDirDataset(str(image_caption_dir))\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n    assert example[\"caption\"] == \"caption 0\"\n\n\ndef test_image_caption_dir_dataset_keep_in_memory(image_caption_dir):  # noqa: F811\n    dataset = ImageCaptionDirDataset(str(image_caption_dir), keep_in_memory=True)\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n    assert example[\"caption\"] == \"caption 0\"\n\n\ndef test_image_caption_dir_dataset_get_image_dimensions(image_caption_dir):  # noqa: F811\n    dataset = ImageCaptionDirDataset(str(image_caption_dir))\n\n    image_dims = dataset.get_image_dimensions()\n\n    assert len(image_dims) == len(dataset)\n\n\ndef test_image_caption_dir_dataset_missing_caption_file(tmp_path: Path):  # noqa: F811\n    # Create a directory with an image but no caption file.\n    with open(tmp_path / \"0.jpg\", \"w\"):\n        pass\n\n    with pytest.raises(Exception, match=r\"The following expected caption files are missing: \\['.*0.txt'\\]\"):\n        ImageCaptionDirDataset(str(tmp_path))\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py",
    "content": "import shutil\nfrom pathlib import Path\n\nimport PIL.Image\n\nfrom invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset\nfrom invoke_training._shared.utils.jsonl import load_jsonl\n\nfrom ..dataset_fixtures import image_caption_jsonl  # noqa: F401\n\n\ndef test_image_caption_jsonl_dataset_len(image_caption_jsonl):  # noqa: F811\n    dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl))\n\n    assert len(dataset) == 5\n\n\ndef test_image_caption_jsonl_dataset_getitem(image_caption_jsonl):  # noqa: F811\n    dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl))\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"mask\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n    assert example[\"caption\"] == \"caption 0\"\n    assert isinstance(example[\"mask\"], PIL.Image.Image)\n    assert example[\"mask\"].mode == \"L\"\n\n\ndef test_image_caption_jsonl_dataset_keep_in_memory(image_caption_jsonl):  # noqa: F811\n    dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl), keep_in_memory=True)\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\", \"caption\", \"mask\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n    assert example[\"caption\"] == \"caption 0\"\n    assert isinstance(example[\"mask\"], PIL.Image.Image)\n    assert example[\"mask\"].mode == \"L\"\n\n    # Confirm that accessing the same example again returns a shallow copy of the original example.\n    # In other words, modifying the returned dict should not modify the cached example, but the same image should be\n    # returned.\n    same_example = dataset[0]\n    assert same_example is not example\n    assert same_example[\"image\"] is 
example[\"image\"]\n\n\ndef test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_jsonl):  # noqa: F811\n    dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl))\n\n    image_dims = dataset.get_image_dimensions()\n\n    assert len(image_dims) == len(dataset)\n\n\ndef test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp_path: Path):  # noqa: F811\n    # Create a copy of the image_caption_jsonl file to avoid modifying the original file.\n    image_caption_jsonl_copy = tmp_path / \"test.jsonl\"\n    shutil.copy(image_caption_jsonl, image_caption_jsonl_copy)\n\n    # Load the dataset from the copied jsonl file.\n    dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl))\n\n    # Save the dataset to a new jsonl file.\n    dataset.save_jsonl()\n\n    # Verify that the roundtrip was successful.\n    assert image_caption_jsonl != image_caption_jsonl_copy\n    original_jsonl = load_jsonl(image_caption_jsonl)\n    roundtrip_jsonl = load_jsonl(image_caption_jsonl_copy)\n    assert original_jsonl == roundtrip_jsonl\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_image_dir_dataset.py",
    "content": "import PIL.Image\n\nfrom invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset\n\nfrom ..dataset_fixtures import image_dir  # noqa: F401\n\n\ndef test_image_dir_dataset_len(image_dir):  # noqa: F811\n    dataset = ImageDirDataset(str(image_dir))\n\n    assert len(dataset) == 5\n\n\ndef test_image_dir_dataset_getitem(image_dir):  # noqa: F811\n    dataset = ImageDirDataset(str(image_dir))\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n\n\ndef test_image_dir_dataset_keep_in_memory(image_dir):  # noqa: F811\n    dataset = ImageDirDataset(str(image_dir), keep_in_memory=True)\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"image\", \"id\"}\n    assert isinstance(example[\"image\"], PIL.Image.Image)\n    assert example[\"image\"].mode == \"RGB\"\n    assert example[\"id\"] == \"0\"\n\n    # Confirm that accessing the same example again returns a shallow copy of the original example.\n    # In other words, modifying the returned dict should not modify the cached example, but the same image should be\n    # returned.\n    same_example = dataset[0]\n    assert same_example is not example\n    assert same_example[\"image\"] is example[\"image\"]\n\n\ndef test_image_dir_dataset_get_image_dimensions(image_dir):  # noqa: F811\n    dataset = ImageDirDataset(str(image_dir))\n\n    image_dims = dataset.get_image_dimensions()\n\n    assert len(image_dims) == len(dataset)\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_image_pair_preference_dataset.py",
    "content": "import PIL.Image\n\nfrom invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset\n\nfrom ..dataset_fixtures import image_pair_preference_dir  # noqa: F401\n\n\ndef test_image_dir_dataset_len(image_pair_preference_dir):  # noqa: F811\n    dataset = ImagePairPreferenceDataset(str(image_pair_preference_dir))\n\n    assert len(dataset) == 6\n\n\ndef test_image_dir_dataset_getitem(image_pair_preference_dir):  # noqa: F811\n    dataset = ImagePairPreferenceDataset(str(image_pair_preference_dir))\n\n    example = dataset[0]\n\n    assert set(example.keys()) == {\"id\", \"image_0\", \"image_1\", \"caption\", \"prefer_0\", \"prefer_1\"}\n\n    assert example[\"id\"] == \"0\"\n\n    assert isinstance(example[\"image_0\"], PIL.Image.Image)\n    assert example[\"image_0\"].mode == \"RGB\"\n    assert isinstance(example[\"image_1\"], PIL.Image.Image)\n    assert example[\"image_1\"].mode == \"RGB\"\n\n    assert example[\"prefer_0\"]\n    assert not example[\"prefer_1\"]\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/datasets/test_transform_dataset.py",
    "content": "import unittest.mock\n\nfrom invoke_training._shared.data.datasets.transform_dataset import TransformDataset\n\n\ndef test_transform_dataset_len():\n    \"\"\"Test the TransformDataset len() function.\"\"\"\n    mock_dataset = unittest.mock.MagicMock()\n    mock_dataset.__len__.return_value = 5\n\n    dataset = TransformDataset(mock_dataset, [])\n\n    assert len(dataset) == 5\n\n\ndef test_transform_dataset_getitem():\n    \"\"\"Test the TransformDataset __getitem__() function.\"\"\"\n    field1 = 1\n    field2 = \"2\"\n    base_example = {\"field1\": field1}\n\n    mock_dataset = unittest.mock.MagicMock()\n    mock_dataset.__getitem__.return_value = base_example\n\n    def mock_transform(example):\n        example[\"field2\"] = field2\n        return example\n\n    dataset = TransformDataset(mock_dataset, [mock_transform])\n\n    out_example = dataset[0]\n\n    assert out_example[\"field1\"] == field1\n    assert out_example[\"field2\"] == field2\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/test_aspect_ratio_bucket_batch_sampler.py",
    "content": "from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (\n    AspectRatioBucketBatchSampler,\n)\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\ndef assert_shuffled_samples_match(samples_1, samples_2):\n    \"\"\"Utility function to assert that two batch sampler outputs are equivalent aside from having been shuffled.\"\"\"\n    # Same number of batches.\n    assert len(samples_1) == len(samples_2)\n    # Same total number of examples.\n    assert sum([len(b) for b in samples_1]) == sum([len(b) for b in samples_2])\n    # Same set of examples.\n    assert {x for batch in samples_1 for x in batch} == {x for batch in samples_2 for x in batch}\n\n\ndef test_aspect_ratio_bucket_batch_sampler():\n    \"\"\"Basic test of AspectRatioBucketBatchSampler.\"\"\"\n    sampler = AspectRatioBucketBatchSampler(\n        buckets={Resolution(256, 768): [1, 3, 5], Resolution(512, 512): [4], Resolution(768, 256): [0, 2]},\n        batch_size=2,\n        shuffle=False,\n        seed=None,\n    )\n\n    assert list(sampler) == [[1, 3], [5], [4], [0, 2]]\n\n\ndef test_aspect_ratio_bucket_batch_sampler_len():\n    \"\"\"Basic test of AspectRatioBucketBatchSampler len(...) 
function.\"\"\"\n    sampler = AspectRatioBucketBatchSampler(\n        buckets={Resolution(256, 768): [1, 3, 5], Resolution(512, 512): [4], Resolution(768, 256): [0, 2]},\n        batch_size=2,\n        shuffle=False,\n        seed=None,\n    )\n\n    assert len(sampler) == len(list(sampler))\n\n\ndef test_aspect_ratio_bucket_batch_sampler_from_image_sizes():\n    \"\"\"Test AspectRatioBucketBatchSampler when initialized with AspectRatioBucketBatchSampler.from_image_size(...).\"\"\"\n    # Configure bucket_manager to have the following aspect ratio buckets:\n    # (256, 1024), (256, 768), (512, 512), (768, 256), (1024, 768)\n    bucket_manager = AspectRatioBucketManager.from_constraints(\n        target_resolution=512, start_dim=256, end_dim=768, divisible_by=256\n    )\n\n    image_sizes = [\n        Resolution(256, 768),  # Bucket 1 (256, 768)\n        Resolution(512, 512),  # Bucket 2 (512, 512)\n        Resolution(768, 256),  # Bucket 3 (768, 256)\n        Resolution(264, 768),  # Bucket 1 (256, 768)\n        Resolution(272, 768),  # Bucket 1 (256, 768)\n        Resolution(768, 264),  # Bucket 3 (768, 256)\n    ]\n\n    sampler = AspectRatioBucketBatchSampler.from_image_sizes(\n        bucket_manager=bucket_manager, image_sizes=image_sizes, batch_size=2, shuffle=False\n    )\n\n    assert list(sampler) == [[0, 3], [4], [1], [2, 5]]\n\n\ndef test_aspect_ratio_bucket_batch_sampler_shuffle():\n    \"\"\"Test AspectRatioBucketBatchSampler shuffle behavior.\"\"\"\n    buckets = {Resolution(256, 512): [1, 3, 5, 6, 7], Resolution(512, 512): [4], Resolution(512, 256): [0, 2]}\n    batch_size = 2\n    unshuffled_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=False, seed=None)\n    shuffled_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=None)\n\n    unshuffled_samples = list(unshuffled_sampler)\n    shuffled_samples = list(shuffled_sampler)\n\n    
assert_shuffled_samples_match(shuffled_samples, unshuffled_samples)\n    # Not equal, because one is shuffled.\n    assert shuffled_samples != unshuffled_samples\n\n\ndef test_aspect_ratio_bucket_batch_sampler_seed():\n    \"\"\"Test AspectRatioBucketBatchSampler seed behavior.\"\"\"\n    buckets = {Resolution(256, 512): [1, 3, 5, 6, 7], Resolution(512, 512): [4], Resolution(512, 256): [0, 2]}\n    batch_size = 2\n    base_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=1)\n    same_seed_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=1)\n    diff_seed_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=2)\n\n    base_samples = list(base_sampler)\n    same_seed_samples = list(same_seed_sampler)\n    diff_seed_samples = list(diff_seed_sampler)\n\n    # Samples generated with the same seed should match exactly.\n    assert base_samples == same_seed_samples\n\n    # Samples generated with different seeds should match, except for the example ordering.\n    assert_shuffled_samples_match(base_samples, diff_seed_samples)\n    assert base_samples != diff_seed_samples\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/test_batch_offset_sampler.py",
    "content": "from torch.utils.data.sampler import BatchSampler, SequentialSampler\n\nfrom invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler\n\n\ndef test_batch_offset_sampler():\n    \"\"\"Test that the BatchOffsetSampler yields the correct sequence of values.\"\"\"\n    sequential_sampler = SequentialSampler([0] * 5)\n    batch_sampler = BatchSampler(sequential_sampler, batch_size=2, drop_last=False)\n\n    batch_offset_sampler = BatchOffsetSampler(sampler=batch_sampler, offset=10)\n\n    assert list(batch_offset_sampler) == [[10, 11], [12, 13], [14]]\n    # Assert that it can be iterated multiple times.\n    assert list(batch_offset_sampler) == [[10, 11], [12, 13], [14]]\n\n\ndef test_batch_offset_sampler_len():\n    \"\"\"Test the BatchOffsetSampler len() function.\"\"\"\n    sequential_sampler = SequentialSampler([0] * 5)\n    batch_sampler = BatchSampler(sequential_sampler, batch_size=2, drop_last=False)\n    batch_offset_sampler = BatchOffsetSampler(sampler=batch_sampler, offset=10)\n    assert len(batch_offset_sampler) == 3\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/test_concat_sampler.py",
    "content": "from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler\n\n\ndef test_concat_sampler():\n    \"\"\"Test that the ConcatSampler yields the correct sequence.\"\"\"\n    sampler_1 = [0, 1, 2, 3]\n    sampler_2 = [4, 5, 6]\n    sampler_3 = [7, 8, 9, 10, 11, 12]\n\n    sampler = ConcatSampler([sampler_1, sampler_2, sampler_3])\n    samples = list(sampler)\n\n    assert samples == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n\n\ndef test_concat_sampler_batches():\n    \"\"\"Test that the ConcatSampler yields the correct sequence with batch samplers.\"\"\"\n    sampler_1 = [[0, 1, 2], [3, 4, 5], [6]]\n    sampler_2 = [[7, 8], [9]]\n    sampler_3 = [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]]\n\n    sampler = ConcatSampler([sampler_1, sampler_2, sampler_3])\n    samples = list(sampler)\n\n    assert samples == [[0, 1, 2], [3, 4, 5], [6], [7, 8], [9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]]\n\n\ndef test_concat_sampler_len():\n    \"\"\"Test the ConcatSampler len() function.\"\"\"\n    sampler_1 = [0, 1, 2, 3]\n    sampler_2 = [4, 5, 6]\n    sampler_3 = [7, 8, 9, 10, 11, 12]\n\n    sampler = ConcatSampler([sampler_1, sampler_2, sampler_3])\n    assert len(sampler) == 13\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/test_interleaved_sampler.py",
    "content": "from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler\n\n\ndef test_interleaved_sampler():\n    \"\"\"Test that the InterleavedSampler yields the correct sequence.\"\"\"\n    sampler_1 = [0, 1, 2, 3]\n    sampler_2 = [4, 5, 6]\n    sampler_3 = [7, 8, 9, 10, 11, 12]\n\n    sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3])\n    samples = list(sampler)\n\n    assert samples == [0, 4, 7, 1, 5, 8, 2, 6, 9]\n\n\ndef test_interleaved_sampler_batches():\n    \"\"\"Test that the InterleavedSampler yields the correct sequence with batch samplers.\"\"\"\n    sampler_1 = [[0, 1, 2], [3, 4, 5], [6]]\n    sampler_2 = [[7, 8], [9]]\n    sampler_3 = [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]]\n\n    sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3])\n    samples = list(sampler)\n\n    assert samples == [[0, 1, 2], [7, 8], [10, 11, 12], [3, 4, 5], [9], [13, 14, 15]]\n\n\ndef test_interleaved_sampler_len():\n    \"\"\"Test the InterleavedSampler len() function.\"\"\"\n    sampler_1 = [0, 1, 2, 3]\n    sampler_2 = [4, 5]\n    sampler_3 = [7, 8, 9, 10, 11, 12]\n\n    sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3])\n    assert len(sampler) == 2 * 3\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/samplers/test_offset_sampler.py",
    "content": "from torch.utils.data.sampler import SequentialSampler\n\nfrom invoke_training._shared.data.samplers.offset_sampler import OffsetSampler\n\n\ndef test_offset_sampler():\n    \"\"\"Test that the OffsetSampler yields the correct sequence of values.\"\"\"\n    sequential_sampler = SequentialSampler([0] * 5)\n    offset_sampler = OffsetSampler(sampler=sequential_sampler, offset=10)\n\n    assert list(offset_sampler) == list(range(10, 15))\n    # Assert that it can be iterated multiple times.\n    assert list(offset_sampler) == list(range(10, 15))\n\n\ndef test_offset_sampler_len():\n    \"\"\"Test the OffsetSampler len() function.\"\"\"\n    sequential_sampler = SequentialSampler([0] * 5)\n    offset_sampler = OffsetSampler(sampler=sequential_sampler, offset=10)\n    assert len(offset_sampler) == 5\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_caption_prefix_transform.py",
    "content": "from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform\n\n\ndef test_caption_prefix_transform():\n    tf = CaptionPrefixTransform(caption_field_name=\"caption\", prefix=\"prefix \")\n\n    in_example = {\"caption\": \"original caption\", \"other\": 2}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"caption\": \"prefix original caption\", \"other\": 2}\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_concat_fields_transform.py",
    "content": "from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform\n\n\ndef test_caption_prefix_transform():\n    tf = ConcatFieldsTransform(src_field_names=[\"caption\", \"caption_2\"], dst_field_name=\"caption\", separator=\", \")\n\n    in_example = {\"caption\": \"original caption\", \"caption_2\": \"another caption\", \"other\": 2}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"caption\": \"original caption, another caption\", \"caption_2\": \"another caption\", \"other\": 2}\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_constant_field_transform.py",
    "content": "from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform\n\n\ndef test_constant_field_transform():\n    tf = ConstantFieldTransform(\"test_field\", 1)\n\n    in_example = {\"existing\": 2}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"existing\": 2, \"test_field\": 1}\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_drop_field_transform.py",
    "content": "from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform\n\n\ndef test_drop_field_transform():\n    tf = DropFieldTransform(\"drop\")\n\n    in_example = {\"keep\": 1, \"drop\": 2}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"keep\": 1}\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_load_cache_transform.py",
    "content": "import unittest.mock\n\nimport torch\n\nfrom invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform\n\n\ndef test_load_cache_transform():\n    cached_tensor = torch.Tensor([1.0, 2.0, 3.0])\n    mock_cache = unittest.mock.MagicMock()\n    mock_cache.load.return_value = {\"cached_tensor\": cached_tensor}\n\n    tf = LoadCacheTransform(\n        cache=mock_cache, cache_key_field=\"cache_key\", cache_field_to_output_field={\"cached_tensor\": \"output\"}\n    )\n\n    in_example = {\"cache_key\": 1}\n\n    out_example = tf(in_example)\n\n    mock_cache.load.assert_called_once_with(1)\n    assert out_example[\"output\"] is cached_tensor\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_sd_image_transform.py",
    "content": "import unittest.mock\n\nimport numpy as np\nimport pytest\nimport torch\nfrom PIL import Image\n\nfrom invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\ndef denormalize_image(img: np.ndarray) -> np.ndarray:\n    \"\"\"Convert a normalized CxHxW image in range [-1.0, 1.0] to a HxWxC image in the range [0, 255].\n\n    Args:\n        img (np.ndarray): Image to denormalize.\n\n    Returns:\n        np.ndarray: Result image.\n    \"\"\"\n    # Convert back to range [0, 1.0].\n    img = img * 0.5 + 0.5\n    # Convert back to range [0, 255].\n    img *= 255\n    # Move channel axis from first dimension to last dimension.\n    img = np.moveaxis(img, 0, -1)\n\n    return img\n\n\ndef denormalize_mask(mask: np.ndarray) -> np.ndarray:\n    \"\"\"Convert a normalized CxHxW mask in range [0.0, 1.0] to a HxW mask in the range [0, 255].\"\"\"\n    # Convert back to range [0, 255].\n    mask *= 255\n    # Squeeze the channel dimension.\n    mask = mask.squeeze(0)\n    return mask\n\n\ndef test_sd_image_transform_resolution():\n    \"\"\"Test that SDImageTransform resizes and crops to the target resolution, and correctly sets original_size_hw.\"\"\"\n    in_image_np = np.ones((256, 128, 3), dtype=np.uint8)\n    in_image_pil = Image.fromarray(in_image_np)\n    in_mask_np = np.ones((256, 128), dtype=np.uint8)\n    in_mask_pil = Image.fromarray(in_mask_np)\n\n    resolution = Resolution(768, 512)\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=resolution,\n    )\n\n    out_example = tf({\"image\": in_image_pil, \"mask\": in_mask_pil})\n\n    out_image = out_example[\"image\"]\n    assert isinstance(out_image, torch.Tensor)\n    assert 
out_image.shape == (3, resolution.height, resolution.width)\n\n    out_mask = out_example[\"mask\"]\n    assert isinstance(out_mask, torch.Tensor)\n    assert out_mask.shape == (1, resolution.height, resolution.width)\n\n    original_size_hw = out_example[\"original_size_hw\"]\n    assert original_size_hw == (256, 128)\n\n\ndef test_sd_image_transform_without_mask():\n    \"\"\"Test that SDImageTransform works correctly when no mask is provided.\"\"\"\n    in_image_np = np.ones((256, 128, 3), dtype=np.uint8)\n    in_image_pil = Image.fromarray(in_image_np)\n\n    resolution = Resolution(768, 512)\n    tf = SDImageTransform(\n        image_field_names=[\"image\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=resolution,\n    )\n\n    # No mask is provided.\n    out_example = tf({\"image\": in_image_pil})\n\n    out_image = out_example[\"image\"]\n    assert isinstance(out_image, torch.Tensor)\n    assert out_image.shape == (3, resolution.height, resolution.width)\n\n    original_size_hw = out_example[\"original_size_hw\"]\n    assert original_size_hw == (256, 128)\n\n\ndef test_sd_image_transform_range():\n    \"\"\"Test that SDImageTransform normalizes the image to the range [-1.0, 1.0], and the mask to the range\n    [0.0, 1.0].\n    \"\"\"\n    resolution = 128\n    in_image_np = np.zeros((resolution, resolution, 3), dtype=np.uint8)\n    in_image_np[0, 0, :] = 255  # Image contains one pixel with value 255, and the rest are zeros.\n    in_image_pil = Image.fromarray(in_image_np)\n\n    in_mask_np = np.zeros((resolution, resolution), dtype=np.uint8)\n    in_mask_np[0, 0] = 255  # Mask contains one pixel with value 255, and the rest are zeros.\n    in_mask_pil = Image.fromarray(in_mask_np)\n\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=resolution,\n    )\n\n    out_example = tf({\"image\": in_image_pil, 
\"mask\": in_mask_pil})\n\n    out_image = out_example[\"image\"]\n    out_np = np.array(out_image)\n    assert np.allclose(out_np[:, 0, 0], 1.0)\n    assert np.allclose(out_np[:, 1:, 1:], -1.0)\n\n    out_mask = out_example[\"mask\"]\n    out_np = np.array(out_mask)\n    assert np.allclose(out_np[0, 0, 0], 1.0)\n    assert np.allclose(out_np[0, 1:, 1:], 0.0)\n\n\ndef test_sd_image_transform_center_crop():\n    \"\"\"Test SDImageTransform center cropping.\"\"\"\n    # Input image is 9 x 5.\n    in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3))\n    in_image_pil = Image.fromarray(np.copy(in_image_np))\n\n    mask_image_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5))\n    mask_image_pil = Image.fromarray(np.copy(mask_image_np))\n\n    # The target resolution is 3x5 (with center cropping).\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=(3, 5),\n        center_crop=True,\n    )\n\n    out_example = tf({\"image\": in_image_pil, \"mask\": mask_image_pil})\n\n    # Verify that the correct region of the image was cropped.\n    out_image = out_example[\"image\"]\n    out_image_np = np.array(out_image)\n    assert np.allclose(denormalize_image(out_image_np), in_image_np[3:-3, :, :])\n    assert out_example[\"crop_top_left_yx\"] == (3, 0)\n\n    # Verify that the correct region of the mask was cropped.\n    out_mask = out_example[\"mask\"]\n    out_mask_np = np.array(out_mask)\n    assert np.allclose(denormalize_mask(out_mask_np), mask_image_np[3:-3, :])\n\n\ndef test_sd_image_transform_random_crop():\n    \"\"\"Test SDImageTransform random cropping.\"\"\"\n    # Input image is 9 x 5.\n    in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3))\n    in_image_pil = Image.fromarray(np.copy(in_image_np))\n\n    mask_image_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5))\n    mask_image_pil = 
Image.fromarray(np.copy(mask_image_np))\n\n    # The target resolution is 3x5 (with random cropping).\n    resolution = Resolution(3, 5)\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=resolution,\n        center_crop=False,\n    )\n\n    out_example = tf({\"image\": in_image_pil, \"mask\": mask_image_pil})\n\n    # Verify that the crop_top_left_yx value is correct.\n    out_image = out_example[\"image\"]\n    out_image_np = np.array(out_image)\n    crop_y, crop_x = out_example[\"crop_top_left_yx\"]\n    assert np.allclose(\n        denormalize_image(out_image_np),\n        in_image_np[crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width, :],\n    )\n\n    # Verify that the mask was cropped in the same way as the image.\n    out_mask = out_example[\"mask\"]\n    out_mask_np = np.array(out_mask)\n    assert np.allclose(\n        denormalize_mask(out_mask_np),\n        mask_image_np[crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width],\n    )\n\n\ndef test_sd_image_transform_center_crop_flip():\n    \"\"\"Test SDImageTransform center cropping with a horizontal flip.\"\"\"\n    # Input image is 5 x 9.\n    in_image_np = np.arange(5 * 9 * 3, dtype=np.uint8).reshape((5, 9, 3))\n    in_image_pil = Image.fromarray(np.copy(in_image_np))\n\n    in_mask_np = np.arange(5 * 9, dtype=np.uint8).reshape((5, 9))\n    in_mask_pil = Image.fromarray(np.copy(in_mask_np))\n\n    # The target resolution is 5x3 (with center cropping and horizontal flipping).\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=Resolution(5, 3),\n        center_crop=True,\n        random_flip=True,\n    )\n\n    # Note: We patch random.random() to force a horizontal flip to be applied.\n    with unittest.mock.patch(\"random.random\", 
return_value=0.0):\n        out_example = tf({\"image\": in_image_pil, \"mask\": in_mask_pil})\n\n    # Verify that the correct region of the image was cropped/flipped.\n    # For this comparison, we flip the in_image_np first, then apply the expected crop.\n    out_image = out_example[\"image\"]\n    out_image_np = np.array(out_image)\n    assert np.allclose(denormalize_image(out_image_np), in_image_np[:, ::-1, :][:, 3:-3, :])\n    assert out_example[\"crop_top_left_yx\"] == (0, 3)\n\n    # Verify that the correct region of the mask was cropped/flipped.\n    out_mask = out_example[\"mask\"]\n    out_mask_np = np.array(out_mask)\n    assert np.allclose(denormalize_mask(out_mask_np), in_mask_np[:, ::-1][:, 3:-3])\n\n\ndef test_sd_image_transform_random_crop_flip():\n    \"\"\"Test SDImageTransform random cropping with a horizontal flip.\"\"\"\n    # Input image is 5 x 9.\n    in_image_np = np.arange(5 * 9 * 3, dtype=np.uint8).reshape((5, 9, 3))\n    in_image_pil = Image.fromarray(np.copy(in_image_np))\n\n    in_mask_np = np.arange(5 * 9, dtype=np.uint8).reshape((5, 9))\n    in_mask_pil = Image.fromarray(np.copy(in_mask_np))\n\n    # The target resolution is 5x3 (with random cropping and horizontal flipping).\n    resolution = Resolution(5, 3)\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=resolution,\n        center_crop=False,\n        random_flip=True,\n    )\n\n    # Note: We patch random.random() to force a horizontal flip to be applied.\n    with unittest.mock.patch(\"random.random\", return_value=0.0):\n        out_example = tf({\"image\": in_image_pil, \"mask\": in_mask_pil})\n\n    # Verify that the crop_top_left_yx value is correct.\n    # For this comparison, we flip the in_image_np first, then apply the expected crop.\n    out_image = out_example[\"image\"]\n    out_image_np = np.array(out_image)\n    crop_y, crop_x = 
out_example[\"crop_top_left_yx\"]\n    assert np.allclose(\n        denormalize_image(out_image_np),\n        in_image_np[:, ::-1, :][crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width, :],\n    )\n\n    # Verify that the mask was cropped in the same way as the image.\n    out_mask = out_example[\"mask\"]\n    out_mask_np = np.array(out_mask)\n    assert np.allclose(\n        denormalize_mask(out_mask_np),\n        in_mask_np[:, ::-1][crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width],\n    )\n\n\ndef test_sd_image_transform_aspect_ratio_bucket_manager():\n    # Input image is 9 x 5.\n    in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3))\n    in_image_pil = Image.fromarray(np.copy(in_image_np))\n\n    in_mask_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5))\n    in_mask_pil = Image.fromarray(np.copy(in_mask_np))\n\n    # Initialize SDImageTransform with an AspectRatioBucketManager that has a single 3x5 bucket.\n    aspect_ratio_bucket_manager = AspectRatioBucketManager(buckets={Resolution(3, 5)})\n    tf = SDImageTransform(\n        image_field_names=[\"image\", \"mask\"],\n        fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n        resolution=None,\n        aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n        center_crop=True,\n    )\n\n    out_example = tf({\"image\": in_image_pil, \"mask\": in_mask_pil})\n\n    # Verify that the correct region of the image was cropped.\n    out_image = out_example[\"image\"]\n    out_image_np = np.array(out_image)\n    assert np.allclose(denormalize_image(out_image_np), in_image_np[3:-3, :, :])\n    assert out_example[\"crop_top_left_yx\"] == (3, 0)\n\n    # Verify that the correct region of the mask was cropped.\n    out_mask = out_example[\"mask\"]\n    out_mask_np = np.array(out_mask)\n    assert np.allclose(denormalize_mask(out_mask_np), in_mask_np[3:-3, :])\n\n\n@pytest.mark.parametrize(\n    [\"resolution\", 
\"aspect_ratio_bucket_manager\"],\n    [\n        (Resolution(512, 512), AspectRatioBucketManager({})),\n        (None, None),\n    ],\n)\ndef test_sd_image_transform_resolution_input_validation(\n    resolution: Resolution | None, aspect_ratio_bucket_manager: AspectRatioBucketManager | None\n):\n    with pytest.raises(ValueError):\n        _ = SDImageTransform(\n            image_field_names=[\"image\", \"mask\"],\n            fields_to_normalize_to_range_minus_one_to_one=[\"image\"],\n            resolution=resolution,\n            aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,\n        )\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_shuffle_caption_transform.py",
    "content": "from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform\n\n\ndef test_shuffle_caption_transform():\n    tf = ShuffleCaptionTransform(field_name=\"test_field\", seed=3)\n\n    in_example = {\"test_field\": \"prompt part 1, prompt part 2\"}\n\n    out_example = tf(in_example)\n\n    # Note that the expected output depends on the seed.\n    assert out_example == {\"test_field\": \"prompt part 2, prompt part 1\"}\n\n\ndef test_shuffle_caption_transform_no_delimiter():\n    tf = ShuffleCaptionTransform(field_name=\"test_field\")\n\n    in_example = {\"test_field\": \"prompt part 1\"}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"test_field\": \"prompt part 1\"}\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_template_caption_transform.py",
    "content": "import pytest\n\nfrom invoke_training._shared.data.transforms.template_caption_transform import (\n    TemplateCaptionTransform,\n)\n\n\ndef test_template_caption_transform():\n    tf = TemplateCaptionTransform(\n        field_name=\"test_field\", placeholder_str=\"placeholder\", caption_templates=[\"template 1 {}\"]\n    )\n\n    in_example = {\"existing\": 2}\n\n    out_example = tf(in_example)\n\n    assert out_example == {\"existing\": 2, \"test_field\": \"template 1 placeholder\"}\n\n\ndef test_template_caption_transform_seed():\n    field_name = \"test_field\"\n    placeholder_str = \"placeholder\"\n    caption_templates = [\"template 1 {}\", \"template 2 {}\"]\n    tf = TemplateCaptionTransform(\n        field_name=field_name,\n        placeholder_str=placeholder_str,\n        caption_templates=caption_templates,\n        seed=123,\n    )\n\n    # Run on 10 examples with baseline seed 123.\n    out_examples = [tf({}) for _ in range(10)]\n\n    # Run on 10 examples with same seed and assert that results match.\n    tf = TemplateCaptionTransform(\n        field_name=field_name,\n        placeholder_str=placeholder_str,\n        caption_templates=caption_templates,\n        seed=123,\n    )\n    out_examples_same_seed = [tf({}) for _ in range(10)]\n    assert out_examples == out_examples_same_seed\n\n    # Run on 10 examples with a different seed and assert that the results don't match.\n    tf = TemplateCaptionTransform(\n        field_name=field_name,\n        placeholder_str=placeholder_str,\n        caption_templates=caption_templates,\n        seed=456,\n    )\n    out_examples_diff_seed = [tf({}) for _ in range(10)]\n    assert out_examples != out_examples_diff_seed\n\n\ndef test_template_caption_transform_bad_templates():\n    tf = TemplateCaptionTransform(\n        field_name=\"test_field\", placeholder_str=\"placeholder\", caption_templates=[\"template 1\"]\n    )\n\n    in_example = {\"existing\": 2}\n\n    with 
pytest.raises(AssertionError):\n        _ = tf(in_example)\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/transforms/test_tensor_disk_cache.py",
    "content": "from pathlib import Path\n\nimport pytest\nimport torch\n\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\n\n\ndef test_tensor_disk_cache_roundtrip(tmp_path: Path):\n    \"\"\"Test a TensorDiskCache cache roundtrip.\"\"\"\n    cache = TensorDiskCache(str(tmp_path))\n\n    in_dict = {\"test_tensor\": torch.rand((1, 2, 3)), \"test_tuple\": (1, 2), \"test_list\": [3, 4], \"test_scalar\": 1}\n\n    # Roundtrip\n    cache.save(0, in_dict)\n    out_dict = cache.load(0)\n\n    assert set(in_dict.keys()) == set(out_dict.keys())\n    torch.testing.assert_close(out_dict[\"test_tensor\"], in_dict[\"test_tensor\"])\n    assert out_dict[\"test_tuple\"] == in_dict[\"test_tuple\"]\n    assert out_dict[\"test_list\"] == in_dict[\"test_list\"]\n    assert out_dict[\"test_scalar\"] == in_dict[\"test_scalar\"]\n\n\ndef test_tensor_disk_cache_fail_overwrite(tmp_path):\n    \"\"\"Test that an attempt to overwrite an existing TensorDiskCache cache entry raises an AssertionError.\"\"\"\n    cache = TensorDiskCache(str(tmp_path))\n    in_dict = {\"test_tensor\": torch.rand((1, 2, 3))}\n    cache.save(0, in_dict)\n\n    with pytest.raises(AssertionError):\n        cache.save(0, in_dict)\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/utils/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/utils/test_aspect_ratio_bucket_manager.py",
    "content": "from contextlib import nullcontext\n\nimport pytest\n\nfrom invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\n@pytest.mark.parametrize(\n    [\"target_resolution\", \"start_dim\", \"end_dim\", \"divisible_by\", \"should_raise\"],\n    [\n        (1024, 512, 2048, 64, False),\n        (1025, 512, 2048, 64, True),  # target_resolution not divisible by divisible_by.\n        (1024, 513, 2048, 64, True),  # start_dim not divisible by divisible_by.\n        (1024, 512, 2049, 64, True),  # end_dim not divisible by divisible_by.\n        (1024, 1024, 512, 64, True),  # start_dim > end_dim.\n    ],\n)\ndef test_build_aspect_ratio_buckets_input_validation(\n    target_resolution: int, start_dim: int, end_dim: int, divisible_by: int, should_raise: bool\n):\n    \"\"\"Test validation of all input params to AspectRatioBucketManager.build_aspect_ratio_buckets(...).\"\"\"\n    expectation = pytest.raises(AssertionError) if should_raise else nullcontext()\n    with expectation:\n        _ = AspectRatioBucketManager.build_aspect_ratio_buckets(\n            target_resolution=target_resolution,\n            start_dim=start_dim,\n            end_dim=end_dim,\n            divisible_by=divisible_by,\n        )\n\n\n@pytest.mark.parametrize(\n    [\"target_resolution\", \"start_dim\", \"end_dim\", \"divisible_by\", \"expected\"],\n    [\n        # 1 bucket\n        (1024, 1024, 1024, 64, {Resolution(1024, 1024)}),\n        # Multiple buckets.\n        (\n            1024,\n            768,\n            1280,\n            128,\n            {\n                Resolution(768, 1280),\n                Resolution(896, 1152),\n                Resolution(1024, 1024),\n                Resolution(1152, 896),\n                Resolution(1280, 768),\n            },\n        ),\n    ],\n)\ndef test_build_aspect_ratio_buckets(\n    target_resolution: int,\n    
start_dim: int,\n    end_dim: int,\n    divisible_by: int,\n    expected: set[Resolution],\n):\n    buckets = AspectRatioBucketManager.build_aspect_ratio_buckets(\n        target_resolution=target_resolution,\n        start_dim=start_dim,\n        end_dim=end_dim,\n        divisible_by=divisible_by,\n    )\n\n    assert buckets == expected\n\n\n@pytest.mark.parametrize(\n    [\"resolution\", \"expected_bucket\"],\n    [\n        (Resolution(1024, 1024), Resolution(1024, 1024)),  # Exact match.\n        (Resolution(128, 1024), Resolution(768, 1280)),  # Small aspect ratio.\n        (Resolution(1024, 128), Resolution(1280, 768)),  # Large aspect ratio.\n    ],\n)\ndef test_get_aspect_ratio_bucket(resolution: Resolution, expected_bucket: Resolution):\n    arbm = AspectRatioBucketManager.from_constraints(\n        target_resolution=1024, start_dim=768, end_dim=1280, divisible_by=128\n    )\n\n    nearest_bucket = arbm.get_aspect_ratio_bucket(resolution)\n\n    assert nearest_bucket == expected_bucket\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/utils/test_resize.py",
    "content": "import numpy as np\nimport pytest\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resize import resize_to_cover\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\n@pytest.mark.parametrize(\n    [\"in_resolution\", \"size_to_cover\", \"expected_resolution\"],\n    [\n        # Perfect match, no resize necessary.\n        (Resolution(512, 768), Resolution(512, 768), Resolution(512, 768)),\n        # Height matches, width covers, no resize necessary.\n        (Resolution(768, 768), Resolution(768, 512), Resolution(768, 768)),\n        # Width matches, height covers, no resize necessary.\n        (Resolution(768, 768), Resolution(512, 768), Resolution(768, 768)),\n        # Height matches, width does not cover, scale up.\n        (Resolution(768, 256), Resolution(768, 512), Resolution(1536, 512)),\n        # Width matches, height does not cover, scale up.\n        (Resolution(256, 768), Resolution(512, 768), Resolution(512, 1536)),\n        # Both width and height exceed target, scale down, limited by height.\n        (Resolution(1024, 768), Resolution(768, 512), Resolution(768, 576)),\n        # Both width and height exceed target, scale down, limited by width.\n        (Resolution(768, 1024), Resolution(512, 768), Resolution(576, 768)),\n    ],\n)\ndef test_resize_to_cover(in_resolution: Resolution, size_to_cover: Resolution, expected_resolution: Resolution):\n    in_img = np.zeros((in_resolution.height, in_resolution.width, 3), dtype=np.uint8)\n    in_img = Image.fromarray(in_img)\n\n    out_img = resize_to_cover(in_img, size_to_cover)\n\n    assert out_img.height == expected_resolution.height\n    assert out_img.width == expected_resolution.width\n"
  },
  {
    "path": "tests/invoke_training/_shared/data/utils/test_resolution.py",
    "content": "import pytest\n\nfrom invoke_training._shared.data.utils.resolution import Resolution\n\n\n@pytest.mark.parametrize(\n    [\"input\", \"expected_resolution\"],\n    [\n        (5, Resolution(5, 5)),  # From int.\n        ((5, 6), Resolution(5, 6)),  # From tuple[int, int].\n        (Resolution(5, 6), Resolution(5, 6)),  # From Resolution.\n    ],\n)\ndef test_resolution_parse(input, expected_resolution: Resolution):\n    resolution = Resolution.parse(input)\n    assert resolution == expected_resolution\n"
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/test_base_model_version.py",
    "content": "import pytest\nfrom transformers import PretrainedConfig\n\nfrom invoke_training._shared.stable_diffusion.base_model_version import (\n    BaseModelVersionEnum,\n    check_base_model_version,\n    get_base_model_version,\n)\n\n\n@pytest.mark.loads_model\n@pytest.mark.parametrize(\n    [\"diffusers_model_name\", \"expected_version\"],\n    [\n        (\"runwayml/stable-diffusion-v1-5\", BaseModelVersionEnum.STABLE_DIFFUSION_V1),\n        (\"stabilityai/stable-diffusion-2-1\", BaseModelVersionEnum.STABLE_DIFFUSION_V2),\n        (\"stabilityai/stable-diffusion-xl-base-1.0\", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE),\n        (\"stabilityai/stable-diffusion-xl-refiner-1.0\", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER),\n    ],\n)\ndef test_get_base_model_version(diffusers_model_name: str, expected_version: BaseModelVersionEnum):\n    \"\"\"Test get_base_model_version(...) with one test model for each supported version.\"\"\"\n    # Check if the diffusers_model_name model is downloaded and xfail if not.\n    # This check ensures that users don't have to download all of the test models just to run the test suite.\n    try:\n        _ = PretrainedConfig.from_pretrained(\n            pretrained_model_name_or_path=diffusers_model_name,\n            subfolder=\"unet\",\n            local_files_only=True,\n        )\n    except OSError:\n        pytest.xfail(f\"'{diffusers_model_name}' is not downloaded.\")\n\n    version = get_base_model_version(diffusers_model_name)\n    assert version == expected_version\n\n\n@pytest.mark.loads_model\ndef test_check_base_model_version_pass():\n    \"\"\"Test that check_base_model_version(...) does not raise an Exception when the model is valid.\"\"\"\n    check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V1}, \"runwayml/stable-diffusion-v1-5\")\n\n\n@pytest.mark.loads_model\ndef test_check_base_model_version_fail():\n    \"\"\"Test that check_base_model_version(...) 
raises a ValueError when the model is invalid.\"\"\"\n    with pytest.raises(ValueError):\n        check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V2}, \"runwayml/stable-diffusion-v1-5\")\n"
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/test_lora_checkpoint_utils.py",
    "content": "from pathlib import Path\n\nimport pytest\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (\n    convert_sd_peft_checkpoint_to_kohya_state_dict,\n)\n\n\ndef test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_directory(tmp_path: Path):\n    with pytest.raises(ValueError, match=\"No checkpoint files found in directory\"):\n        convert_sd_peft_checkpoint_to_kohya_state_dict(\n            in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / \"out.safetensors\"\n        )\n\n\ndef test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpected_subdirectory(tmp_path: Path):\n    subdirectory = tmp_path / \"subdir\"\n    subdirectory.mkdir()\n\n    with pytest.raises(ValueError, match=f\"Unrecognized checkpoint directory: '{subdirectory}'.\"):\n        convert_sd_peft_checkpoint_to_kohya_state_dict(\n            in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / \"out.safetensors\"\n        )\n"
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py",
    "content": "import logging\nfrom pathlib import Path\n\nimport pytest\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer\n\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd, load_models_sdxl\n\nfrom .ti_embedding_checkpoint_fixture import (  # noqa: F401\n    sdv1_embedding_path,\n    sdxl_embedding_path,\n)\n\n\n@pytest.mark.loads_model\ndef test_load_models_sd(sdv1_embedding_path):  # noqa: F811\n    model_name = \"runwayml/stable-diffusion-v1-5\"\n\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logging.getLogger(__name__),\n        model_name_or_path=model_name,\n        hf_variant=\"fp16\",\n        base_embeddings={\"special_test_token\": str(sdv1_embedding_path)},\n    )\n\n    token_ids = tokenizer.encode(\"special_test_token special_test_token_1\", add_special_tokens=False)\n    assert len(token_ids) == 2\n\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    for token_id in token_ids:\n        # The embedding should be all zeros, because that is how it was initialized in the sdv1_embedding_path\n        # fixture.\n        assert torch.allclose(token_embeds[token_id], torch.zeros_like(token_embeds[token_id]))\n\n\n@pytest.mark.loads_model\ndef test_load_models_sdxl(sdxl_embedding_path: Path):  # noqa: F811\n    model_name = \"stabilityai/stable-diffusion-xl-base-1.0\"\n\n    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(\n        logger=logging.getLogger(__name__),\n        model_name_or_path=model_name,\n        hf_variant=\"fp16\",\n        base_embeddings={\"special_test_token\": str(sdxl_embedding_path)},\n    )\n\n    # Validate that the embeddings were applied correctly.\n    def validate_ti_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel):\n        token_ids = tokenizer.encode(\"special_test_token special_test_token_1\", add_special_tokens=False)\n      
  assert len(token_ids) == 2\n\n        token_embeds = text_encoder.get_input_embeddings().weight.data\n        for token_id in token_ids:\n            # The embedding should be all zeros, because that is how it was initialized in the sdxl_embedding_path\n            # fixture.\n            assert torch.allclose(token_embeds[token_id], torch.zeros_like(token_embeds[token_id]))\n\n    validate_ti_embeddings(tokenizer_1, text_encoder_1)\n    validate_ti_embeddings(tokenizer_2, text_encoder_2)\n"
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py",
    "content": "import logging\nfrom pathlib import Path\n\nimport pytest\nimport torch\n\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd\nfrom invoke_training._shared.stable_diffusion.textual_inversion import (\n    _expand_placeholder_token,\n    initialize_placeholder_tokens_from_initial_embedding,\n    initialize_placeholder_tokens_from_initial_phrase,\n    initialize_placeholder_tokens_from_initializer_token,\n)\n\nfrom .ti_embedding_checkpoint_fixture import sdv1_embedding_path  # noqa: F401\n\n\n@pytest.mark.parametrize(\n    [\"placeholder_token\", \"num_vectors\", \"expected_placeholder_tokens\"],\n    [(\"abc\", 1, [\"abc\"]), (\"abc\", 2, [\"abc\", \"abc_1\"]), (\"abc\", 3, [\"abc\", \"abc_1\", \"abc_2\"])],\n)\ndef test_expand_placeholder_token(placeholder_token: str, num_vectors: int, expected_placeholder_tokens: list[str]):\n    assert _expand_placeholder_token(placeholder_token, num_vectors) == expected_placeholder_tokens\n\n\ndef test_expand_placeholder_token_raises_on_invalid_num_vectors():\n    with pytest.raises(ValueError):\n        _expand_placeholder_token(\"abc\", 0)\n\n\n@pytest.mark.loads_model\ndef test_initialize_placeholder_tokens_from_initializer_token():\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logging.getLogger(__name__), model_name_or_path=\"runwayml/stable-diffusion-v1-5\", hf_variant=\"fp16\"\n    )\n\n    initializer_token = \"dog\"\n    num_vectors = 2\n    placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initializer_token(\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        initializer_token=initializer_token,\n        placeholder_token=\"dog_placeholder\",\n        num_vectors=num_vectors,\n        logger=logging.getLogger(),\n    )\n\n    assert len(placeholder_tokens) == num_vectors\n    assert len(placeholder_token_ids) == num_vectors\n    assert placeholder_tokens == 
[\"dog_placeholder\", \"dog_placeholder_1\"]\n\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    initializer_token_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0]\n    with torch.no_grad():\n        for placeholder_token_id in placeholder_token_ids:\n            assert torch.allclose(token_embeds[placeholder_token_id], token_embeds[initializer_token_id])\n\n\n@pytest.mark.loads_model\ndef test_initialize_placeholder_tokens_from_initial_phrase():\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logging.getLogger(__name__), model_name_or_path=\"runwayml/stable-diffusion-v1-5\", hf_variant=\"fp16\"\n    )\n\n    initial_phrase = \"little brown dog\"\n    placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_phrase(\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        initial_phrase=initial_phrase,\n        placeholder_token=\"dog_placeholder\",\n    )\n\n    expected_num_vectors = 3\n    assert len(placeholder_tokens) == expected_num_vectors\n    assert len(placeholder_token_ids) == expected_num_vectors\n    assert placeholder_tokens == [\"dog_placeholder\", \"dog_placeholder_1\", \"dog_placeholder_2\"]\n\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    initial_token_ids = tokenizer.encode(initial_phrase, add_special_tokens=False)\n    assert len(initial_token_ids) == expected_num_vectors\n    with torch.no_grad():\n        for placeholder_token_id, initial_token_id in zip(placeholder_token_ids, initial_token_ids):\n            assert torch.allclose(token_embeds[placeholder_token_id], token_embeds[initial_token_id])\n\n\n@pytest.mark.loads_model\ndef test_initialize_placeholder_tokens_from_initial_embedding(sdv1_embedding_path: Path):  # noqa: F811\n    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(\n        logger=logging.getLogger(__name__), 
model_name_or_path=\"runwayml/stable-diffusion-v1-5\", hf_variant=\"fp16\"\n    )\n\n    placeholder_token = \"custom_token\"\n    num_vectors = 2\n    placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_embedding(\n        tokenizer=tokenizer,\n        text_encoder=text_encoder,\n        initial_embedding_file=str(sdv1_embedding_path),\n        placeholder_token=placeholder_token,\n        num_vectors=num_vectors,\n    )\n\n    assert len(placeholder_tokens) == num_vectors\n    assert len(placeholder_token_ids) == num_vectors\n    assert placeholder_tokens == [\"custom_token\", \"custom_token_1\"]\n\n    token_embeds = text_encoder.get_input_embeddings().weight.data\n    with torch.no_grad():\n        for placeholder_token_id in placeholder_token_ids:\n            # The placeholder embeddings should be initialized to zero, because this is how they are initialized in the\n            # dummy sdv1_embedding_path checkpoint.\n            assert torch.allclose(\n                token_embeds[placeholder_token_id], torch.zeros_like(token_embeds[placeholder_token_id])\n            )\n"
  },
  {
    "path": "tests/invoke_training/_shared/stable_diffusion/ti_embedding_checkpoint_fixture.py",
    "content": "import pytest\nimport torch\n\nfrom invoke_training._shared.checkpoints.serialization import save_state_dict\n\n\n@pytest.fixture(scope=\"session\")\ndef sdv1_embedding_path(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that writes a dummy SD v1 TI embedding to a temp dir and returns the embedding path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session. Refer to\n    https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of\n    tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"embeddings\")\n\n    embedding_state_dict = {\"custom_token\": torch.zeros((2, 768))}\n\n    embedding_path = tmp_dir / \"embedding.safetensors\"\n    save_state_dict(embedding_state_dict, embedding_path)\n\n    return embedding_path\n\n\n@pytest.fixture(scope=\"session\")\ndef sdxl_embedding_path(tmp_path_factory: pytest.TempPathFactory):\n    \"\"\"A fixture that writes a dummy SDXL TI embedding to a temp dir and returns the embedding path.\n\n    Note that the 'session' scope is used to share the same directory across all tests in a session. Refer to\n    https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of\n    tmp_path_factory.\n    \"\"\"\n    tmp_dir = tmp_path_factory.mktemp(\"embeddings\")\n\n    embedding_state_dict = {\n        \"clip_l\": torch.zeros((2, 768)),\n        \"clip_g\": torch.zeros((2, 1280)),\n    }\n\n    embedding_path = tmp_dir / \"embedding.safetensors\"\n    save_state_dict(embedding_state_dict, embedding_path)\n\n    return embedding_path\n"
  },
  {
    "path": "tests/invoke_training/_shared/utils/test_jsonl.py",
    "content": "from pathlib import Path\n\nfrom invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl\n\n\ndef test_jsonl_roundtrip(tmp_path: Path):\n    in_objs = [{\"a\": 1, \"b\": 2}, {\"a\": 1, \"b\": 2}]\n    jsonl_path = tmp_path / \"test.jsonl\"\n\n    save_jsonl(in_objs, jsonl_path)\n    out_objs = load_jsonl(jsonl_path)\n\n    assert in_objs == out_objs\n"
  },
  {
    "path": "tests/invoke_training/config/pipelines/test_pipeline_config.py",
    "content": "import glob\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\n\n\ndef test_pipeline_config():\n    \"\"\"Test that all sample pipeline configs can be parsed as PipelineConfigs.\"\"\"\n    cur_file = Path(__file__)\n    config_dir = cur_file.parent.parent.parent.parent.parent / \"src/invoke_training/sample_configs\"\n    config_files = glob.glob(str(config_dir) + \"/**/*.yaml\", recursive=True)\n\n    assert len(config_files) > 0\n\n    for config_file in config_files:\n        with open(config_file, \"r\") as f:\n            cfg = yaml.safe_load(f)\n\n        pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)\n\n        try:\n            _ = pipeline_adapter.validate_python(cfg)\n        except Exception as e:\n            raise Exception(f\"Error parsing config file: {config_file}\") from e\n"
  },
  {
    "path": "tests/invoke_training/model_merge/__init__.py",
    "content": ""
  },
  {
    "path": "tests/invoke_training/model_merge/test_merge_models.py",
    "content": "import math\nfrom typing import Literal\n\nimport pytest\nimport torch\n\nfrom invoke_training.model_merge.merge_models import merge_models\n\nfrom .utils import state_dicts_are_close\n\n\ndef test_merge_models_raises_on_not_enough_state_dicts():\n    with pytest.raises(ValueError, match=\"Must provide >=2 models to merge.\"):\n        _ = merge_models(state_dicts=[{}], weights=[0.5], merge_method=\"LERP\")\n\n\ndef test_merge_models_raises_on_mismatched_weights():\n    with pytest.raises(ValueError, match=\"Must provide a weight for each model.\"):\n        _ = merge_models(state_dicts=[{}, {}], weights=[0.5, 0.5, 0.5], merge_method=\"LERP\")\n\n\n@pytest.mark.parametrize(\n    [\"state_dicts\", \"weights\", \"merge_method\", \"expected_state_dict\"],\n    [\n        # Lerp.\n        (\n            [\n                {\"a\": torch.tensor(1.0), \"b\": torch.tensor(2.0)},\n                {\"a\": torch.tensor(3.0), \"b\": torch.tensor(4.0)},\n            ],\n            [1.0, 1.0],\n            \"LERP\",\n            {\"a\": torch.tensor(2.0), \"b\": torch.tensor(3.0)},\n        ),\n        # Lerp with unbalanced weights.\n        (\n            [\n                {\"a\": torch.tensor(1.0), \"b\": torch.tensor(2.0)},\n                {\"a\": torch.tensor(3.0), \"b\": torch.tensor(4.0)},\n            ],\n            [1.0, 3.0],\n            \"LERP\",\n            {\"a\": torch.tensor(1.0 * 0.25 + 3.0 * 0.75), \"b\": torch.tensor(2.0 * 0.25 + 4.0 * 0.75)},\n        ),\n        # Lerp with more than 2 state dicts.\n        (\n            [\n                {\"a\": torch.tensor(1.0), \"b\": torch.tensor(2.0)},\n                {\"a\": torch.tensor(2.0), \"b\": torch.tensor(3.0)},\n                {\"a\": torch.tensor(3.0), \"b\": torch.tensor(4.0)},\n            ],\n            [1.0, 1.0, 1.0],\n            \"LERP\",\n            {\"a\": torch.tensor(2.0), \"b\": torch.tensor(3.0)},\n        ),\n        # Slerp with scalar tensors falls back to lerp.\n   
     (\n            [\n                {\"a\": torch.tensor(1.0), \"b\": torch.tensor(2.0)},\n                {\"a\": torch.tensor(3.0), \"b\": torch.tensor(4.0)},\n            ],\n            [1.0, 1.0],\n            \"SLERP\",\n            {\"a\": torch.tensor(2.0), \"b\": torch.tensor(3.0)},\n        ),\n        # Slerp with colinear vector tensors falls back to lerp.\n        (\n            [\n                {\"a\": torch.tensor([1.0, 2.0])},\n                {\"a\": torch.tensor([2.0, 4.0])},\n            ],\n            [1.0, 1.0],\n            \"SLERP\",\n            {\"a\": torch.tensor([1.5, 3.0])},\n        ),\n        # Slerp with orthogonal vector tensors.\n        (\n            [\n                {\"a\": torch.tensor([1.0, 0.0])},\n                {\"a\": torch.tensor([0.0, 1.0])},\n            ],\n            [1.0, 1.0],\n            \"SLERP\",\n            {\"a\": torch.tensor([math.sin(math.pi / 4), math.sin(math.pi / 4)])},\n        ),\n    ],\n)\ndef test_merge_models(\n    state_dicts: list[dict[str, torch.Tensor]],\n    weights: list[float],\n    merge_method: Literal[\"LERP\", \"SLERP\"],\n    expected_state_dict: dict[str, torch.Tensor],\n):\n    merged_state_dict = merge_models(state_dicts=state_dicts, weights=weights, merge_method=merge_method)\n    assert state_dicts_are_close(merged_state_dict, expected_state_dict)\n"
  },
  {
    "path": "tests/invoke_training/model_merge/test_merge_tasks_to_base.py",
    "content": "from typing import Literal\n\nimport pytest\nimport torch\n\nfrom invoke_training.model_merge.merge_tasks_to_base import merge_tasks_to_base_model\n\nfrom .utils import state_dicts_are_close\n\n\ndef test_merge_raises_on_mismatched_weights():\n    with pytest.raises(ValueError, match=\"Must provide a weight for each model.\"):\n        _ = merge_tasks_to_base_model({}, [{}, {}], [0.5, 0.5, 0.5])\n\n\n@pytest.mark.parametrize(\n    [\"base_state_dict\", \"task_state_dicts\", \"task_weights\", \"density\", \"merge_method\", \"expected_state_dict\"],\n    [\n        # TIES.\n        (\n            {\"a\": torch.tensor([1.0, 2.0]), \"b\": torch.tensor([3.0, 4.0])},\n            [\n                {\"a\": torch.tensor([2.0, 7.0]), \"b\": torch.tensor([3.0, 6.0])},\n                {\"a\": torch.tensor([7.0, 3.0]), \"b\": torch.tensor([3.0, 7.0])},\n            ],\n            [1.0, 1.0],\n            0.5,\n            \"TIES\",\n            # Expected task diff state dict:\n            # {\"a\": torch.tensor([1.0, 5.0]), \"b\": torch.tensor([0.0, 2.0])},\n            # {\"a\": torch.tensor([6.0, 1.0]), \"b\": torch.tensor([0.0, 3.0])},\n            # Expected merged diff state dict:\n            # {\"a\": torch.tensor([6.0, 5.0]), \"b\": torch.tensor([0.0, 2.5])},\n            # Expected final result:\n            {\"a\": torch.tensor([7.0, 7.0]), \"b\": torch.tensor([3.0, 6.5])},\n        ),\n        # DARE_LINEAR.\n        (\n            {\"a\": torch.tensor([1.0, 2.0]), \"b\": torch.tensor([3.0, 4.0])},\n            [\n                {\"a\": torch.tensor([2.0, 7.0]), \"b\": torch.tensor([3.0, 6.0])},\n                {\"a\": torch.tensor([7.0, 3.0]), \"b\": torch.tensor([3.0, 7.0])},\n            ],\n            [1.0, 1.0],\n            # Set density to 1.0 so that we can set an expected result without having to handle seeding the RNG.\n            1.0,\n            \"DARE_LINEAR\",\n            {\"a\": torch.tensor([8.0, 8.0]), \"b\": 
torch.tensor([3.0, 9.0])},\n        ),\n        # DARE_TIES.\n        (\n            {\"a\": torch.tensor([1.0, 2.0]), \"b\": torch.tensor([3.0, 4.0])},\n            [\n                {\"a\": torch.tensor([2.0, 7.0]), \"b\": torch.tensor([3.0, 6.0])},\n                {\"a\": torch.tensor([7.0, 3.0]), \"b\": torch.tensor([3.0, 7.0])},\n            ],\n            [1.0, 1.0],\n            # Set density to 1.0 so that we can set an expected result without having to handle seeding the RNG.\n            1.0,\n            \"DARE_TIES\",\n            {\"a\": torch.tensor([4.5, 5.0]), \"b\": torch.tensor([3.0, 6.5])},\n        ),\n    ],\n)\ndef test_merge_ties(\n    base_state_dict: dict[str, torch.Tensor],\n    task_state_dicts: list[dict[str, torch.Tensor]],\n    task_weights: list[float],\n    density: float,\n    merge_method: Literal[\"TIES\", \"DARE_LINEAR\", \"DARE_TIES\"],\n    expected_state_dict: dict[str, torch.Tensor],\n):\n    merged_state_dict = merge_tasks_to_base_model(\n        base_state_dict=base_state_dict,\n        task_state_dicts=task_state_dicts,\n        task_weights=task_weights,\n        density=density,\n        merge_method=merge_method,\n    )\n    assert state_dicts_are_close(merged_state_dict, expected_state_dict)\n"
  },
  {
    "path": "tests/invoke_training/model_merge/utils.py",
    "content": "import torch\n\n\ndef state_dicts_are_close(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]) -> bool:\n    \"\"\"Helper function for comparing two state dicts.\"\"\"\n    return all(torch.allclose(a[key], b[key]) for key in a.keys())\n"
  },
  {
    "path": "tests/invoke_training/ui/utils/test_prompts.py",
    "content": "import pytest\n\nfrom invoke_training.ui.utils.prompts import (\n    convert_pos_neg_prompts_to_ui_prompts,\n    convert_ui_prompts_to_pos_neg_prompts,\n    split_pos_neg_prompts,\n)\n\n\n@pytest.mark.parametrize(\n    [\"prompt\", \"expected_positive_prompt\", \"expected_negative_prompt\"],\n    [\n        # Simple positive and negative prompt.\n        (\"positive prompt[NEG]negative prompt\", \"positive prompt\", \"negative prompt\"),\n        # Positive prompt with no negative prompt.\n        (\"positive prompt\", \"positive prompt\", \"\"),\n        # Empty prompt.\n        (\"\", \"\", \"\"),\n    ],\n)\ndef test_split_pos_neg_prompts(prompt: str, expected_positive_prompt: str, expected_negative_prompt: str):\n    positive_prompt, negative_prompt = split_pos_neg_prompts(prompt)\n    assert positive_prompt == expected_positive_prompt\n    assert negative_prompt == expected_negative_prompt\n\n\n@pytest.mark.parametrize(\n    \"prompt\",\n    [\n        # Multiple negative prompt delimiters.\n        \"positive prompt[NEG]negative prompt[NEG]negative prompt\",\n    ],\n)\ndef test_split_pos_neg_prompts_raises_value_error(prompt: str):\n    with pytest.raises(ValueError):\n        split_pos_neg_prompts(prompt)\n\n\n# Test cases for conversion between UI prompts and positive/negative prompts.\n# Each test case consists of: (ui_prompts, positive_prompts, negative_prompts)\nprompt_conversion_test_cases = [\n    # Positive prompts.\n    (\n        \"positive prompt 1\\npositive prompt 2\\npositive prompt 3\",\n        [\"positive prompt 1\", \"positive prompt 2\", \"positive prompt 3\"],\n        None,\n    ),\n    # Positive prompts with trailing \\n.\n    (\n        \"positive prompt 1\\npositive prompt 2\\npositive prompt 3\\n\",\n        [\"positive prompt 1\", \"positive prompt 2\", \"positive prompt 3\"],\n        None,\n    ),\n    # Positive and negative prompts.\n    (\n        \"positive prompt 1[NEG]negative prompt 1\\npositive prompt 
2[NEG]negative prompt 2\\n\"\n        \"positive prompt 3[NEG]negative prompt 3\\n\",\n        [\"positive prompt 1\", \"positive prompt 2\", \"positive prompt 3\"],\n        [\"negative prompt 1\", \"negative prompt 2\", \"negative prompt 3\"],\n    ),\n    # Some missing negative prompts.\n    (\n        \"positive prompt 1[NEG]negative prompt 1\\npositive prompt 2\\npositive prompt 3[NEG]negative prompt 3\\n\",\n        [\"positive prompt 1\", \"positive prompt 2\", \"positive prompt 3\"],\n        [\"negative prompt 1\", \"\", \"negative prompt 3\"],\n    ),\n]\n\n\n@pytest.mark.parametrize(\n    [\"ui_prompts\", \"expected_positive_prompts\", \"expected_negative_prompts\"], prompt_conversion_test_cases\n)\ndef test_convert_ui_prompts_to_pos_neg_prompts(\n    ui_prompts: str, expected_positive_prompts: list[str], expected_negative_prompts: list[str | None] | None\n):\n    positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_prompts)\n    assert positive_prompts == expected_positive_prompts\n    assert negative_prompts == expected_negative_prompts\n\n\n@pytest.mark.parametrize([\"expected_ui_prompts\", \"positive_prompts\", \"negative_prompts\"], prompt_conversion_test_cases)\ndef test_convert_pos_neg_prompts_to_ui_prompts(\n    expected_ui_prompts: str, positive_prompts: list[str], negative_prompts: list[str | None] | None\n):\n    ui_prompts = convert_pos_neg_prompts_to_ui_prompts(positive_prompts, negative_prompts)\n    assert ui_prompts == expected_ui_prompts.strip()\n"
  }
]